From a142c365dd6cb0c019d5031d612cd6da98e8562f Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Tue, 23 Dec 2025 14:13:32 +0100 Subject: [PATCH 001/111] use syslink for wall-x readme (#2708) * use syslink for wall-x readme * remove whitespace --- docs/source/policy_walloss_README.md | 35 ++++++++++++++++++++++++++ src/lerobot/policies/wall_x/README.md | 36 +-------------------------- 2 files changed, 36 insertions(+), 35 deletions(-) create mode 100644 docs/source/policy_walloss_README.md mode change 100644 => 120000 src/lerobot/policies/wall_x/README.md diff --git a/docs/source/policy_walloss_README.md b/docs/source/policy_walloss_README.md new file mode 100644 index 000000000..78548bd8d --- /dev/null +++ b/docs/source/policy_walloss_README.md @@ -0,0 +1,35 @@ +# WALL-OSS + +This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. 
+ +--- + +## Model Overview + +| Feature | Description | +| ------------------ | ----------------------------------------------------- | --- | +| Base Model | Qwen2.5-VL (Vision-Language Model) | +| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | +| Architecture | Mixture of Experts (MoE) with action-specific routing | | +| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | + +--- + +## Citation + +If you use this work, please cite: + +```bibtex +@article{zhai2025igniting, + title = {Igniting VLMs Toward the Embodied Space}, + author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach}, + journal = {arXiv preprint arXiv:2509.11766}, + year = {2025} +} +``` + +--- + +## License + +This port follows the **Apache 2.0 License**. diff --git a/src/lerobot/policies/wall_x/README.md b/src/lerobot/policies/wall_x/README.md deleted file mode 100644 index 78548bd8d..000000000 --- a/src/lerobot/policies/wall_x/README.md +++ /dev/null @@ -1,35 +0,0 @@ -# WALL-OSS - -This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. 
- ---- - -## Model Overview - -| Feature | Description | -| ------------------ | ----------------------------------------------------- | --- | -| Base Model | Qwen2.5-VL (Vision-Language Model) | -| Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | -| Architecture | Mixture of Experts (MoE) with action-specific routing | | -| Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | - ---- - -## Citation - -If you use this work, please cite: - -```bibtex -@article{zhai2025igniting, - title = {Igniting VLMs Toward the Embodied Space}, - author = {Zhai, Andy and Liu, Brae and Fang, Bruno and Cai, Chalse and Ma, Ellie and Yin, Ethan and Wang, Hao and Zhou, Hugo and Wang, James and Shi, Lights and Liang, Lucy and Wang, Make and Wang, Qian and Gan, Roy and Yu, Ryan and Li, Shalfun and Liu, Starrick and Chen, Sylas and Chen, Vincent and Xu, Zach}, - journal = {arXiv preprint arXiv:2509.11766}, - year = {2025} -} -``` - ---- - -## License - -This port follows the **Apache 2.0 License**. 
diff --git a/src/lerobot/policies/wall_x/README.md b/src/lerobot/policies/wall_x/README.md new file mode 120000 index 000000000..4426967de --- /dev/null +++ b/src/lerobot/policies/wall_x/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_walloss_README.md \ No newline at end of file From ff271e8b512d2164e40c47fe730d3de092489418 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Tue, 23 Dec 2025 23:58:34 +0100 Subject: [PATCH 002/111] pi fixes for dependencies (#2706) * pi fixes for dependencies * add walls sarm conflict * also add conflicts for pi * fix(ci): use --extra all instead of --all-extras + --no-extra --------- Co-authored-by: Steven Palma --- .github/workflows/full_tests.yml | 2 +- .github/workflows/unbound_deps_tests.yml | 2 +- pyproject.toml | 39 +++++++++++++++++++++++- src/lerobot/policies/pi0/modeling_pi0.py | 7 +++-- 4 files changed, 44 insertions(+), 6 deletions(-) diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index ad222b04f..7962d4a3a 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -85,7 +85,7 @@ jobs: python-version: ${{ env.PYTHON_VERSION }} - name: Install lerobot with all extras - run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional + run: uv sync --extra all # TODO(Steven): Make flash-attn optional - name: Run pytest (all extras) run: uv run pytest tests -vv --maxfail=10 diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index 95562d0dd..bf93dcda6 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -78,7 +78,7 @@ jobs: echo "Dependencies unbound:" && cat pyproject.toml - name: Install lerobot with all extras - run: uv sync --all-extras --no-extra groot --no-extra wallx # TODO(Steven): Make flash-attn optional + run: uv sync --extra all # TODO(Steven): Make flash-attn optional - name: 
Run pytest (all extras) run: uv run pytest tests -vv diff --git a/pyproject.toml b/pyproject.toml index 48e071d32..61b802bc5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -168,7 +168,7 @@ all = [ "lerobot[kinematics]", "lerobot[intelrealsense]", # "lerobot[wallx]", - "lerobot[pi]", + # "lerobot[pi]", TODO(Pepijn): Update pi to transformers v5 "lerobot[smolvla]", # "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn "lerobot[xvla]", @@ -405,6 +405,10 @@ conflicts = [ { extra = "wallx" }, { extra = "xvla" }, ], + [ + { extra = "wallx" }, + { extra = "sarm" }, + ], [ { extra = "wallx" }, { extra = "hilserl" }, @@ -417,4 +421,37 @@ conflicts = [ { extra = "wallx" }, { extra = "all" }, ], + # pi uses custom branch which conflicts with transformers-dep + [ + { extra = "pi" }, + { extra = "transformers-dep" }, + ], + [ + { extra = "pi" }, + { extra = "smolvla" }, + ], + [ + { extra = "pi" }, + { extra = "groot" }, + ], + [ + { extra = "pi" }, + { extra = "xvla" }, + ], + [ + { extra = "pi" }, + { extra = "sarm" }, + ], + [ + { extra = "pi" }, + { extra = "hilserl" }, + ], + [ + { extra = "pi" }, + { extra = "libero" }, + ], + [ + { extra = "pi" }, + { extra = "all" }, + ], ] diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 0d9c77e00..e4970dcf1 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -93,10 +93,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) - alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) - beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + # Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU + alpha_t = torch.tensor(alpha, dtype=torch.float32) + beta_t = torch.tensor(beta, dtype=torch.float32) dist = 
torch.distributions.Beta(alpha_t, beta_t) - return dist.sample((bsize,)) + return dist.sample((bsize,)).to(device) def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy) From 2f238fce15ec226998c29c4369a27de6a5c7092d Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Dec 2025 00:40:56 +0100 Subject: [PATCH 003/111] feat(ci): adds release versioning to docs (#2709) * feat(ci): adds release versioning to docs * chore(ci): remove TODO --- .github/workflows/documentation.yml | 7 +++++-- .github/workflows/release.yml | 1 - 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 3007578fc..48a10e4bc 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -33,6 +33,9 @@ on: paths: - "docs/**" + release: + types: [published] + # Ensures that only the latest commit for a PR or branch is built, canceling older runs. concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -43,7 +46,7 @@ jobs: build_main_docs: name: Build Main Docs if: > - (github.event_name == 'push' || github.event_name == 'workflow_dispatch') && + (github.event_name == 'push' || github.event_name == 'workflow_dispatch' || github.event_name == 'release') && github.repository == 'huggingface/lerobot' permissions: contents: read @@ -51,7 +54,7 @@ jobs: with: commit_sha: ${{ github.sha }} package: lerobot - additional_args: --not_python_module + additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }} secrets: token: ${{ secrets.HUGGINGFACE_PUSH }} hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 7b159dd17..c61307d6c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -177,4 +177,3 @@ jobs: # TODO(Steven): Publish draft/pre-release and 
to test pypi weekly # TODO(Steven): Separate build and publish job -# TODO(Steven): Tag documentation with the same version as the package From 20c22a27996576b7471fa1721933a3146af39849 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 24 Dec 2025 02:03:12 +0100 Subject: [PATCH 004/111] chore(ci): make keyword matching more conservative (#2711) --- .github/workflows/issue_labeler.yml | 44 +++++++++++------------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/.github/workflows/issue_labeler.yml b/.github/workflows/issue_labeler.yml index 27ca2b5f9..438184e3f 100644 --- a/.github/workflows/issue_labeler.yml +++ b/.github/workflows/issue_labeler.yml @@ -42,38 +42,26 @@ jobs: // Keyword Heuristics - // Domain Specific - if (matches(/\b(bug|error|issue|fault|crash|exception)\b/i)) labelsToAdd.add('bug'); - if (matches(/\b(feature|enhancement|improvement|support|implement|proposal)\b/i)) labelsToAdd.add('enhancement'); - if (matches(/\b(question|help|how to||clarify|explain|unclear)\b/i)) labelsToAdd.add('question'); - if (matches(/\b(maintenance|documentation|docs|readme|tutorial|guide|wiki)\b/i)) labelsToAdd.add('documentation'); - if (matches(/\b(example|script|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples'); + if (matches(/\b(bug|error|crash|exception)\b/i)) labelsToAdd.add('bug'); + if (matches(/\b(new feature|enhancement|improvement|proposal|feature request)\b/i)) labelsToAdd.add('enhancement'); + if (matches(/\b(question|how to|clarify|explain|how do i|help me|question about)\b/i)) labelsToAdd.add('question'); + if (matches(/\b(documentation|docs?|readme|tutorial|wiki|typo|docstring)\b/i)) labelsToAdd.add('documentation'); + if (matches(/\b(example|sample|demo|notebook)s?\b/i)) labelsToAdd.add('examples'); if (matches(/\b(datasets?|data loader|data augmentation|data preprocessing)\b/i)) labelsToAdd.add('dataset'); if (matches(/\b(mujoco|isaac|simulation|sim)\b/i)) labelsToAdd.add('simulation'); - if 
(matches(/\b(train|training|loss|optimizer|backward|gradient|wandb|sac)\b/i)) labelsToAdd.add('training'); - if (matches(/\b(rerun|plot|video|render|visualiz|gif)/i)) labelsToAdd.add('visualization'); - if (matches(/\b(camera|realsense|lidar|depth|sensor|imu|microphone|rgbd)\b/i)) labelsToAdd.add('sensors'); - if (matches(/\b(aloha|koch|so-100|so100|mobile|teleop|manipulator|robots?)\b/i)) labelsToAdd.add('robots'); + if (matches(/\b(train|training|optimizer|gradient|wandb|sac)\b/i)) labelsToAdd.add('training'); + if (matches(/\b(rerun|plot|render|rendering|visualizer)/i)) labelsToAdd.add('visualization'); + if (matches(/\b(cameras?|opencv|realsense|lidars?|sensors?|imus?|microphones?|rgbd|encoders?)\b/i)) labelsToAdd.add('sensors'); + if (matches(/\b(urdf|actuators?|calibration|end-effector|kinematics)\b/i)) labelsToAdd.add('robots'); if (matches(/\b(teleop|teleoperator|controller|leader|follower|joystick|gamepad)\b/i)) labelsToAdd.add('teleoperators'); - if (matches(/\b(policy|policies|p0licy)\b/i)) labelsToAdd.add('policies'); - if (matches(/\b(processors?|pipeline)\b/i)) labelsToAdd.add('processor'); - if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmark)\b/i)) labelsToAdd.add('evaluation'); - - // Infrastructure & Code Quality + if (matches(/\b(policy|policies|model?)\b/i)) labelsToAdd.add('policies'); + if (matches(/\b(processor|pipeline|preprocessor|postprocessor)s?\b/i)) labelsToAdd.add('processor'); + if (matches(/\b(eval|evaluate|evaluation|metrics?|score|benchmarks?)\b/i)) labelsToAdd.add('evaluation'); if (matches(/\b(tests?|pytest|unittest|failing test)\b/i)) labelsToAdd.add('tests'); - if (matches(/\b(ci|github actions|workflow|gha|actions?|pipeline)\b/i)) { - labelsToAdd.add('CI'); - labelsToAdd.add('github_actions'); - } - if (matches(/\b(perf|latency|throughput|fps|speed|performance)\b/i)) labelsToAdd.add('performance'); - if (matches(/\b(dependency|requirements|pip|conda|install error|importerror|package not found)\b/i)) 
labelsToAdd.add('dependencies'); - if (matches(/\b(python|pyproject|requirements(\.txt)?|pip install|typing error)\b/i)) labelsToAdd.add('python'); - - // Documentation & Meta - if (matches(/\b(doc|documentation|docs|readme|typo|how to)\b/i)) labelsToAdd.add('documentation'); - if (matches(/\b(refactor|cleanup|restructure|rename|modernize code)\b/i)) labelsToAdd.add('refactor'); - if (matches(/\b(release|changelog|version bump|cut a release|tag v)\b/i)) labelsToAdd.add('release'); - if (matches(/\b(breaking change|major change)\b/i)) labelsToAdd.add('breaking change'); + if (matches(/\b(ci|github actions?|github workflows?|gha|docker|pypi)\b/i)) labelsToAdd.add('CI'); + if (matches(/\b(perf|latency|throughput|fps|speed|performance|slow|fast|slower|faster|memory usage)\b/i)) labelsToAdd.add('performance'); + if (matches(/\b(dependency|dependencies|pip|install error|importerror|package not found|pyproject)\b/i)) labelsToAdd.add('dependencies'); + if (matches(/\b(configuration|config|arguments?|input feature|dracuss)\b/i)) labelsToAdd.add('configuration'); // Apply Labels const labels = Array.from(labelsToAdd).filter(Boolean); From a06f4b914045453b243a5d0049bbcc5f15d4dc86 Mon Sep 17 00:00:00 2001 From: Salman Chishti Date: Wed, 24 Dec 2025 09:42:29 +0000 Subject: [PATCH 005/111] Upgrade GitHub Actions for Node 24 compatibility (#2691) --- .github/workflows/fast_tests.yml | 2 +- .github/workflows/full_tests.yml | 4 ++-- .github/workflows/nightly.yml | 4 ++-- .github/workflows/quality.yml | 4 ++-- .github/workflows/release.yml | 6 +++--- .github/workflows/security.yml | 2 +- .github/workflows/unbound_deps_tests.yml | 4 ++-- 7 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/fast_tests.yml b/.github/workflows/fast_tests.yml index ffd3195c2..10ec91199 100644 --- a/.github/workflows/fast_tests.yml +++ b/.github/workflows/fast_tests.yml @@ -62,7 +62,7 @@ jobs: HF_HOME: /mnt/cache/.cache/huggingface HF_LEROBOT_HOME: 
/mnt/cache/.cache/huggingface/lerobot steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: persist-credentials: false lfs: true diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 7962d4a3a..7bf2cb78c 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -61,7 +61,7 @@ jobs: HF_HOME: /mnt/cache/.cache/huggingface HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: lfs: true persist-credentials: false @@ -127,7 +127,7 @@ jobs: sudo apt-get update sudo apt-get install git-lfs git lfs install - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: lfs: true persist-credentials: false diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index 94d5cc9f2..45bfb9bd5 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -52,7 +52,7 @@ jobs: sudo apt-get update sudo apt-get install git-lfs git lfs install - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: lfs: true persist-credentials: false @@ -87,7 +87,7 @@ jobs: sudo apt-get update sudo apt-get install git-lfs git lfs install - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: lfs: true persist-credentials: false diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index e9f73ed23..0dc94cdd4 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -43,12 +43,12 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: persist-credentials: false - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.10' diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c61307d6c..bcab4c262 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -38,12 +38,12 @@ jobs: 
steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: persist-credentials: false - name: Set up Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: '3.10' @@ -135,7 +135,7 @@ jobs: env: MUJOCO_GL: egl steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: lfs: true persist-credentials: false diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 04497307b..50c0c1fc3 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -43,7 +43,7 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v4 # zizmor: ignore[unpinned-uses] + uses: actions/checkout@v6 # zizmor: ignore[unpinned-uses] with: fetch-depth: 0 persist-credentials: false diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index bf93dcda6..3908bdc3d 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -49,7 +49,7 @@ jobs: HF_HOME: /mnt/cache/.cache/huggingface HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: lfs: true persist-credentials: false @@ -101,7 +101,7 @@ jobs: sudo apt-get update sudo apt-get install git-lfs git lfs install - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: lfs: true persist-credentials: false From 12043b3b5cfd104c7d6201992d761fa39576fb5c Mon Sep 17 00:00:00 2001 From: Alexis Alva Date: Wed, 24 Dec 2025 11:45:14 -0300 Subject: [PATCH 006/111] fix: use importlib.metadata for plugin discovery to support PEP 660 (#2687) --- src/lerobot/utils/import_utils.py | 34 ++++++++++++++++++------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 0dd9db516..3a01aee88 100644 --- a/src/lerobot/utils/import_utils.py +++ 
b/src/lerobot/utils/import_utils.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +import importlib.metadata import logging -import pkgutil from typing import Any from draccus.choice_types import ChoiceRegistry @@ -132,24 +132,30 @@ def make_device_from_device_class(config: ChoiceRegistry) -> Any: def register_third_party_plugins() -> None: """ - Discover and import third-party lerobot_* plugins so they can register themselves. + Discover and import third-party LeRobot plugins so they can register themselves. - Scans top-level modules on sys.path for packages starting with - 'lerobot_robot_', 'lerobot_camera_', 'lerobot_teleoperator_' or 'lerobot_policy_' and imports them. + This function uses `importlib.metadata` to find packages installed in the environment + (including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_', + 'lerobot_teleoperator_', or 'lerobot_policy_' and imports them. 
""" prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_") imported: list[str] = [] failed: list[str] = [] - for module_info in pkgutil.iter_modules(): - name = module_info.name - if name.startswith(prefixes): - try: - importlib.import_module(name) - imported.append(name) - logging.info("Imported third-party plugin: %s", name) - except Exception: - logging.exception("Could not import third-party plugin: %s", name) - failed.append(name) + def attempt_import(module_name: str): + try: + importlib.import_module(module_name) + imported.append(module_name) + logging.info("Imported third-party plugin: %s", module_name) + except Exception: + logging.exception("Could not import third-party plugin: %s", module_name) + failed.append(module_name) + + for dist in importlib.metadata.distributions(): + dist_name = dist.metadata.get("Name") + if not dist_name: + continue + if dist_name.startswith(prefixes): + attempt_import(dist_name) logging.debug("Third-party plugin import summary: imported=%s failed=%s", imported, failed) From 60efd875fad262d1343e14867bd0ad21fbbe862f Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 26 Dec 2025 23:57:17 +0100 Subject: [PATCH 007/111] resolve path correctlt (#2710) Co-authored-by: Steven Palma --- src/lerobot/utils/rabc.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/lerobot/utils/rabc.py b/src/lerobot/utils/rabc.py index c529f3ccc..dc0c61c69 100644 --- a/src/lerobot/utils/rabc.py +++ b/src/lerobot/utils/rabc.py @@ -20,6 +20,18 @@ from pathlib import Path import numpy as np import pandas as pd import torch +from huggingface_hub import hf_hub_download + + +def resolve_hf_path(path: str | Path) -> Path: + """Resolve a path that may be a HuggingFace URL (hf://datasets/...) 
to a local path.""" + path_str = str(path) + if path_str.startswith("hf://datasets/"): + parts = path_str.replace("hf://datasets/", "").split("/") + repo_id = "/".join(parts[:2]) + filename = "/".join(parts[2:]) + return Path(hf_hub_download(repo_id=repo_id, filename=filename, repo_type="dataset")) + return Path(path) class RABCWeights: @@ -51,7 +63,7 @@ class RABCWeights: fallback_weight: float = 1.0, device: torch.device = None, ): - self.progress_path = Path(progress_path) + self.progress_path = resolve_hf_path(progress_path) self.chunk_size = chunk_size self.head_mode = head_mode self.kappa = kappa From 6d0d65a5fec490c966b186bf9455ed3ea82f3f52 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 28 Dec 2025 01:45:06 +0100 Subject: [PATCH 008/111] chore: adds dynamic README handling and setup script (#2724) --- docker/Dockerfile.internal | 2 +- docker/Dockerfile.user | 2 +- pyproject.toml | 2 +- setup.py | 72 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 3 deletions(-) create mode 100644 setup.py diff --git a/docker/Dockerfile.internal b/docker/Dockerfile.internal index 2616cd06c..c1dfa1dae 100644 --- a/docker/Dockerfile.internal +++ b/docker/Dockerfile.internal @@ -73,7 +73,7 @@ ENV HOME=/home/user_lerobot \ RUN uv venv --python python${PYTHON_VERSION} # Install Python dependencies for caching -COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./ +COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot src/ src/ ARG UNBOUND_DEPS=false diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index c1b284453..031165930 100644 --- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -59,7 +59,7 @@ ENV HOME=/home/user_lerobot \ RUN uv venv # Install Python dependencies for caching -COPY --chown=user_lerobot:user_lerobot pyproject.toml README.md MANIFEST.in ./ +COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md 
MANIFEST.in ./ COPY --chown=user_lerobot:user_lerobot src/ src/ ARG UNBOUND_DEPS=false diff --git a/pyproject.toml b/pyproject.toml index 61b802bc5..33586c5e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ discord = "https://discord.gg/s3KuuzsPFb" name = "lerobot" version = "0.4.3" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" -readme = "README.md" +dynamic = ["readme"] license = { text = "Apache-2.0" } requires-python = ">=3.10" authors = [ diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..d97b6f835 --- /dev/null +++ b/setup.py @@ -0,0 +1,72 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from setuptools import setup + + +def get_version_from_toml() -> str: + """Return the project's version string parsed from `pyproject.toml`. + + The function scans `pyproject.toml` line-by-line looking for a line + that starts with ``version`` (for example: ``version = "1.2.3"``) + and returns the value without surrounding quotes. If no such line is + found a :class:`ValueError` is raised. + + Returns: + The version string from `pyproject.toml` (e.g. ``"1.2.3"`` -> + ``1.2.3``). 
+ """ + + version = None + with open("pyproject.toml", encoding="utf-8") as f: + for line in f: + if line.strip().startswith("version"): + version = line.split("=")[1].strip().strip('"') + break + if version is None: + raise ValueError("Version not found in pyproject.toml") + return version + + +def read_long_description() -> str: + """Read and return the project's long description for setup. + + This function reads `README.md` and replaces image links that point + to the local `./media/` directory with absolute raw GitHub URLs that + reference the release tag corresponding to the version parsed from + `pyproject.toml` (for example, ``v1.2.3``). The modified README + content is returned as a string suitable for passing to + ``setuptools.setup(long_description=...)``. + + Returns: + The README content with rewritten media links. + """ + + with open("README.md", encoding="utf-8") as f: + content = f.read() + + version = get_version_from_toml() + git_tag = f"v{version}" + + base_raw_url = f"https://raw.githubusercontent.com/huggingface/lerobot/{git_tag}/" + content = content.replace('src="./media/', f'src="{base_raw_url}media/') + + return content + + +setup( + long_description=read_long_description(), + long_description_content_type="text/markdown", +) From 9701b9c2735a989a6f07b5b3128db73408245c72 Mon Sep 17 00:00:00 2001 From: arya <106710524+pxe3@users.noreply.github.com> Date: Wed, 31 Dec 2025 06:54:28 -0800 Subject: [PATCH 009/111] feat(pi0): add train_expert_only and freeze_vision_encoder flags to pi0 and pi0.5 (#2727) * feat(pi0): add train_expert_only and freeze_vision_encoder options * pi_05: train_expert_only and freeze_vision_encoder flags * comment clean up * docs: add finetuning parameters to pi0 and pi05 docs * updating docs to follow standards --- docs/source/pi0.mdx | 11 +++++++++ docs/source/pi05.mdx | 11 +++++++++ src/lerobot/policies/pi0/configuration_pi0.py | 4 ++++ src/lerobot/policies/pi0/modeling_pi0.py | 24 +++++++++++++++++++ 
.../policies/pi05/configuration_pi05.py | 4 ++++ src/lerobot/policies/pi05/modeling_pi05.py | 24 +++++++++++++++++++ 6 files changed, 78 insertions(+) diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx index d15f7e91f..89604b6aa 100644 --- a/docs/source/pi0.mdx +++ b/docs/source/pi0.mdx @@ -64,6 +64,8 @@ python src/lerobot/scripts/lerobot_train.py \ --policy.compile_model=true \ --policy.gradient_checkpointing=true \ --policy.dtype=bfloat16 \ + --policy.freeze_vision_encoder=false \ + --policy.train_expert_only=false \ --steps=3000 \ --policy.device=cuda \ --batch_size=32 @@ -79,6 +81,15 @@ python src/lerobot/scripts/lerobot_train.py \ - [lerobot/pi0_base](https://huggingface.co/lerobot/pi0_base) - [lerobot/pi0_libero](https://huggingface.co/lerobot/pi0_libero) (specifically trained on the Libero dataset) +### Training Parameters Explained + +| Parameter | Default | Description | +| ----------------------- | ------- | ------------------------------------------- | +| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder | +| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters | + +**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage. + ## License This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). 
diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx index 29b797935..dbf118aa3 100644 --- a/docs/source/pi05.mdx +++ b/docs/source/pi05.mdx @@ -67,6 +67,8 @@ python src/lerobot/scripts/lerobot_train.py\ --policy.gradient_checkpointing=true \ --wandb.enable=true \ --policy.dtype=bfloat16 \ + --policy.freeze_vision_encoder=false \ + --policy.train_expert_only=false \ --steps=3000 \ --policy.device=cuda \ --batch_size=32 @@ -82,6 +84,15 @@ python src/lerobot/scripts/lerobot_train.py\ - [lerobot/pi05_base](https://huggingface.co/lerobot/pi05_base) - [lerobot/pi05_libero](https://huggingface.co/lerobot/pi05_libero) (specifically trained on the Libero dataset) +### Training Parameters Explained + +| Parameter | Default | Description | +| ----------------------- | ------- | ------------------------------------------- | +| `freeze_vision_encoder` | `false` | Do not freeze the vision encoder | +| `train_expert_only` | `false` | Do not freeze the VLM, train all parameters | + +**💡 Tip**: Setting `train_expert_only=true` freezes the VLM and trains only the action expert and projections, allowing finetuning with reduced memory usage. 
+ If your dataset is not converted with `quantiles`, you can convert it with the following command: ```bash diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index 33753e0b2..d91145aa7 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -76,6 +76,10 @@ class PI0Config(PreTrainedConfig): compile_mode: str = "max-autotune" # Torch compile mode device: str | None = None # Device to use for the model (None = auto-detect) + # Finetuning settings + freeze_vision_encoder: bool = False # Freeze only the vision encoder + train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections + # Optimizer settings: see openpi `AdamW`` optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` optimizer_betas: tuple[float, float] = (0.9, 0.95) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index e4970dcf1..0445d6c00 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -339,10 +339,14 @@ class PaliGemmaWithExpertModel( use_adarms=None, precision: Literal["bfloat16", "float32"] = "bfloat16", image_size: int = DEFAULT_IMAGE_SIZE, + freeze_vision_encoder: bool = False, + train_expert_only: bool = False, ): if use_adarms is None: use_adarms = [False, False] super().__init__() + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf._vocab_size = 257152 # noqa: SLF001 @@ -383,6 +387,7 @@ class PaliGemmaWithExpertModel( self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) + self._set_requires_grad() def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): if precision == "bfloat16": @@ -406,6 +411,23 @@ class PaliGemmaWithExpertModel( if any(selector 
in name for selector in params_to_keep_float32): param.data = param.data.to(dtype=torch.float32) + def _set_requires_grad(self): + if self.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + for param in self.paligemma.vision_tower.parameters(): + param.requires_grad = False + if self.train_expert_only: + self.paligemma.eval() + for param in self.paligemma.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + if self.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + if self.train_expert_only: + self.paligemma.eval() + def embed_image(self, image: torch.Tensor): return self.paligemma.model.get_image_features(image) @@ -533,6 +555,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` use_adarms=[False, False], precision=config.dtype, image_size=config.image_resolution[0], + freeze_vision_encoder=config.freeze_vision_encoder, + train_expert_only=config.train_expert_only, ) self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width) diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 7bdce70dd..4a0e23039 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -76,6 +76,10 @@ class PI05Config(PreTrainedConfig): compile_mode: str = "max-autotune" # Torch compile mode device: str | None = None # Device to use for the model (None = auto-detect) + # Finetuning settings + freeze_vision_encoder: bool = False # Freeze only the vision encoder + train_expert_only: bool = False # Freeze entire VLM, train only action expert and projections + # Optimizer settings: see openpi `AdamW` optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` optimizer_betas: tuple[float, float] = (0.9, 0.95) diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 2cd142042..11d8b4d68 100644 --- 
a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -337,10 +337,14 @@ class PaliGemmaWithExpertModel( use_adarms=None, precision: Literal["bfloat16", "float32"] = "bfloat16", image_size: int = DEFAULT_IMAGE_SIZE, + freeze_vision_encoder: bool = False, + train_expert_only: bool = False, ): if use_adarms is None: use_adarms = [False, False] super().__init__() + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only vlm_config_hf = CONFIG_MAPPING["paligemma"]() vlm_config_hf._vocab_size = 257152 # noqa: SLF001 @@ -381,6 +385,7 @@ class PaliGemmaWithExpertModel( self.gemma_expert.model.embed_tokens = None self.to_bfloat16_for_selected_params(precision) + self._set_requires_grad() def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): if precision == "bfloat16": @@ -404,6 +409,23 @@ class PaliGemmaWithExpertModel( if any(selector in name for selector in params_to_keep_float32): param.data = param.data.to(dtype=torch.float32) + def _set_requires_grad(self): + if self.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + for param in self.paligemma.vision_tower.parameters(): + param.requires_grad = False + if self.train_expert_only: + self.paligemma.eval() + for param in self.paligemma.parameters(): + param.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + if self.freeze_vision_encoder: + self.paligemma.vision_tower.eval() + if self.train_expert_only: + self.paligemma.eval() + def embed_image(self, image: torch.Tensor): return self.paligemma.model.get_image_features(image) @@ -531,6 +553,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` use_adarms=[False, True], precision=config.dtype, image_size=config.image_resolution[0], + freeze_vision_encoder=config.freeze_vision_encoder, + train_expert_only=config.train_expert_only, ) self.action_in_proj = nn.Linear(config.max_action_dim, 
action_expert_config.width) From fc296548cb6438b3036d046b43fe91951c87ea9a Mon Sep 17 00:00:00 2001 From: Kartik Date: Fri, 2 Jan 2026 20:36:24 +0100 Subject: [PATCH 010/111] feat(envs): Add NVIDIA IsaacLab-Arena Lerobot (#2699) * adding Isaaclab Arena from collab * adding into lerobot-eval * minor modification * added bash script for env setup * setups * fix applauncher not getting the arguments * data conversion, train and eval smolvla * fixed imports * clean-up * added test suits & clean up - wip * fixed video recording * clean-up * hub integration working * clean-up * added kwargs * Revert "added kwargs" This reverts commit 9b445356385d0707655cf04d02be058b25138119. * added kwargs * clean-up * cleaned unused function * added logging * docs * cleaned up IsaaclabArenaEnv * clean-up * clean-up * clean up * added tests * minor clean-up * fix: support for state based envs * feat(envs): Add NVIDIA IsaacLab Arena integration with LeRobot for policy evaluation at scale * feat(envs): Add IsaacLab Arena integration for policy evaluation Integrate NVIDIA IsaacLab Arena with LeRobot to enable GPU-accelerated simulation through the EnvHub infrastructure. This enables: - Training imitation learning policies (PI0, SmolVLA, etc.) - Evaluating trained policies in with IsaacLab Arena The implementation adds: - IsaaclabArenaEnv config with Arena-specific parameters - IsaaclabArenaProcessorStep for observation processing - Hub loading from nvkartik/isaaclab-arena-envs repository - Video recording support Available environments include GR1 microwave manipulation, Galileo pick-and-place, G1 loco-manipulation, and button pressing tasks. 
Datasets: nvkartik/Arena-GR1-Manipulation-Task Policies: nvkartik/pi05-arena-gr1-microwave, nvkartik/smolvla-arena-gr1-microwave * added isaaclab arena wrapper and corresponding tests * added error handling * renamed wrapper file: isaaclab_arena to isaaclab * added extra kwarg changes * adjustments for hub envs * correct class name in test file * fixed parsing of env_kwargs * tested end to end * removed unused code * refactor design * shifted IsaacLab to hub * removed IsaacLab tests * docs: Add LW-BenchHub evaluation instructions * docs: Add LW-BenchHub evaluation instructions * docs diet * minor edits to texts * IL Arena commit hash * update links * minor edits * fix numpy version after install of lerobot * links update * valideated on vanilla brev * docs: Add LW-BenchHub evaluation instructions * remove kwargs from all make_env calls * remove kwargs from all make_env calls * fix LW table and indentations * remove environment list from docs * docs: Update lw-benchhub eval config in envhub docs * removing kwargs * removed extra line * ensure pinocchio install for lightwheel + add lightwheel website link * remove env_kwargs * no default empty value for hub_path * not using assert method * remove env_processor defaults * revert and adding default "" value for hub_path * pinning down packages versions * explicit None value for hub_path * Update src/lerobot/configs/eval.py Co-authored-by: Jade Choghari Signed-off-by: Lior Ben Horin * corrected formatting * corrected job_name var in config * updated docs and namespace * updated namespace * updated docs * updated docs * added hardware requirements * updated docs --------- Signed-off-by: Lior Ben Horin Co-authored-by: lbenhorin Co-authored-by: Lior Ben Horin Co-authored-by: Jade Choghari Co-authored-by: Steven Palma Co-authored-by: tianheng.wu --- docs/source/_toctree.yml | 2 + docs/source/envhub_isaaclab_arena.mdx | 474 +++++++++++++++++++++++++ src/lerobot/configs/eval.py | 2 + src/lerobot/envs/__init__.py | 2 +- 
src/lerobot/envs/configs.py | 86 ++++- src/lerobot/envs/factory.py | 45 ++- src/lerobot/envs/utils.py | 20 +- src/lerobot/processor/env_processor.py | 77 +++- src/lerobot/scripts/lerobot_eval.py | 7 +- tests/processor/test_pipeline.py | 4 +- 10 files changed, 704 insertions(+), 15 deletions(-) create mode 100644 docs/source/envhub_isaaclab_arena.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 7766b3472..b59d15734 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,6 +59,8 @@ title: Environments from the Hub - local: envhub_leisaac title: Control & Train Robots in Sim (LeIsaac) + - local: envhub_isaaclab_arena + title: NVIDIA IsaacLab Arena Environments - local: libero title: Using Libero - local: metaworld diff --git a/docs/source/envhub_isaaclab_arena.mdx b/docs/source/envhub_isaaclab_arena.mdx new file mode 100644 index 000000000..454a003a0 --- /dev/null +++ b/docs/source/envhub_isaaclab_arena.mdx @@ -0,0 +1,474 @@ +# NVIDIA IsaacLab Arena & LeRobot + +LeRobot EnvHub now supports **GPU-accelerated simulation** with IsaacLab Arena for policy evaluation at scale. +Train and evaluate imitation learning policies with high-fidelity simulation — all integrated into the LeRobot ecosystem. + +IsaacLab Arena - GR1 Microwave Environment + +[IsaacLab Arena](https://github.com/isaac-sim/IsaacLab-Arena) integrates with NVIDIA IsaacLab to provide: + +- 🤖 **Humanoid embodiments**: GR1, G1, Galileo with various configurations +- 🎯 **Manipulation & loco-manipulation tasks**: Microwave opening, pick-and-place, button pressing +- ⚡ **GPU-accelerated rollouts**: Parallel environment execution on NVIDIA GPUs +- 🖼️ **RTX Rendering**: Evaluate vision-based policies with realistic rendering, reflections and refractions. 
+- 📦 **LeRobot-compatible datasets**: Ready for training with GR00T Nx, PI0, SmolVLA, ACT, Diffusion policies +- 🔄 **EnvHub integration**: Load environments from HuggingFace Hub with one line + +## Installation + +### Prerequisites + +Hardware requirements are shared with Isaac Sim, and are detailed in [Isaac Sim Requirements](https://docs.isaacsim.omniverse.nvidia.com/5.1.0/installation/requirements.html). + +- NVIDIA GPU with CUDA support +- NVIDIA driver compatible with IsaacSim 5.1.0 +- Linux (Ubuntu 22.04 / 24.04) + +### Setup + +```bash +# 1. Create conda environment +conda create -y -n lerobot-arena python=3.11 +conda activate lerobot-arena +conda install -y -c conda-forge ffmpeg=7.1.1 + +# 2. Install Isaac Sim 5.1.0 +pip install "isaacsim[all,extscache]==5.1.0" --extra-index-url https://pypi.nvidia.com + +# Accept NVIDIA EULA (required) +export ACCEPT_EULA=Y +export PRIVACY_CONSENT=Y + +# 3. Install IsaacLab 2.3.0 +git clone https://github.com/isaac-sim/IsaacLab.git +cd IsaacLab +git checkout v2.3.0 +./isaaclab.sh -i +cd .. + +# 4. Install IsaacLab Arena +git clone https://github.com/isaac-sim/IsaacLab-Arena.git +cd IsaacLab-Arena +git checkout release/0.1.1 +pip install -e . +cd .. + + +# 5. Install LeRobot +git clone https://github.com/huggingface/lerobot.git +cd lerobot +pip install -e . +cd .. + + +# 6. 
Install additional dependencies +pip install onnxruntime==1.23.2 lightwheel-sdk==1.0.1 vuer[all]==0.0.70 qpsolvers==4.8.1 +pip install numpy==1.26.0 # Isaac Sim 5.1 depends on numpy==1.26.0, this will be fixed in the next release +``` + +## Evaluating Policies + +### Pre-trained Policies + +The following trained policies are available: + +| Policy | Architecture | Task | Link | +| :-------------------------- | :----------- | :------------ | :----------------------------------------------------------------------- | +| pi05-arena-gr1-microwave | PI0.5 | GR1 Microwave | [HuggingFace](https://huggingface.co/nvidia/pi05-arena-gr1-microwave) | +| smolvla-arena-gr1-microwave | SmolVLA | GR1 Microwave | [HuggingFace](https://huggingface.co/nvidia/smolvla-arena-gr1-microwave) | + +### Evaluate SmolVLA + +```bash +pip install -e ".[smolvla]" +pip install numpy==1.26.0 # revert numpy to version 1.26 +``` + +```bash +lerobot-eval \ + --policy.path=nvidia/smolvla-arena-gr1-microwave \ + --env.type=isaaclab_arena \ + --env.hub_path=nvidia/isaaclab-arena-envs \ + --rename_map='{"observation.images.robot_pov_cam_rgb": "observation.images.robot_pov_cam"}' \ + --policy.device=cuda \ + --env.environment=gr1_microwave \ + --env.embodiment=gr1_pink \ + --env.object=mustard_bottle \ + --env.headless=false \ + --env.enable_cameras=true \ + --env.video=true \ + --env.video_length=10 \ + --env.video_interval=15 \ + --env.state_keys=robot_joint_pos \ + --env.camera_keys=robot_pov_cam_rgb \ + --trust_remote_code=True \ + --eval.batch_size=1 +``` + +### Evaluate PI0.5 + +```bash +pip install -e ".[pi]" +pip install numpy==1.26.0 # revert numpy to version 1.26 +``` + +PI0.5 requires disabling torch compile for evaluation: + +```bash +TORCH_COMPILE_DISABLE=1 TORCHINDUCTOR_DISABLE=1 lerobot-eval \ + --policy.path=nvidia/pi05-arena-gr1-microwave \ + --env.type=isaaclab_arena \ + --env.hub_path=nvidia/isaaclab-arena-envs \ + --rename_map='{"observation.images.robot_pov_cam_rgb":
"observation.images.robot_pov_cam"}' \ + --policy.device=cuda \ + --env.environment=gr1_microwave \ + --env.embodiment=gr1_pink \ + --env.object=mustard_bottle \ + --env.headless=false \ + --env.enable_cameras=true \ + --env.video=true \ + --env.video_length 15 \ + --env.video_interval 15 \ + --env.state_keys=robot_joint_pos \ + --env.camera_keys=robot_pov_cam_rgb \ + --trust_remote_code=True \ + --eval.batch_size=1 +``` + + + To change the number of parallel environments, use the ```--eval.batch_size``` + flag. + + +### What to Expect + +During evaluation, you will see a progress bar showing the running success rate: + +``` +Stepping through eval batches: 8%|██████▍ | 4/50 [00:45<08:06, 10.58s/it, running_success_rate=25.0%] +``` + +### Video Recording + +To enable video recording during evaluation, add the following flags to your command: + +```bash +--env.video=true \ +--env.video_length=15 \ +--env.video_interval=15 +``` + +For more details on video recording, see the [IsaacLab Recording Documentation](https://isaac-sim.github.io/IsaacLab/main/source/how-to/record_video.html). 
+ + +When running headless with `--env.headless=true`, you must also enable cameras explicitly for camera-enabled environments: + +```bash +--env.headless=true --env.enable_cameras=true +``` + + + +### Output Directory + +Evaluation videos are saved to the output directory with the following structure: + +``` +outputs/eval/<date>/<time>_<env_type>_<policy_type>/videos/<task>_<env_index>/eval_episode_<episode_index>.mp4 +``` + +For example: + +``` +outputs/eval/2026-01-02/14-38-01_isaaclab_arena_smolvla/videos/gr1_microwave_0/eval_episode_0.mp4 +``` + +## Training Policies + +To learn more about training policies with LeRobot, please refer to the training documentation: + +- [SmolVLA](./smolvla) +- [Pi0.5](./pi05) +- [GR00T N1.5](./groot) + +Sample IsaacLab Arena datasets are available on HuggingFace Hub for experimentation: + +| Dataset | Description | Frames | +| :-------------------------------------------------------------------------------------------------------- | :------------------------- | :----- | +| [Arena-GR1-Manipulation-Task](https://huggingface.co/datasets/nvidia/Arena-GR1-Manipulation-Task-v3) | GR1 microwave manipulation | ~4K | +| [Arena-G1-Loco-Manipulation-Task](https://huggingface.co/datasets/nvidia/Arena-G1-Loco-Manipulation-Task) | G1 loco-manipulation | ~4K | + +## Environment Configuration + +### Full Configuration Options + +```python +from lerobot.envs.configs import IsaaclabArenaEnv + +config = IsaaclabArenaEnv( + # Environment selection + environment="gr1_microwave", # Task environment + embodiment="gr1_pink", # Robot embodiment + object="power_drill", # Object to manipulate + + # Simulation settings + episode_length=300, # Max steps per episode + headless=True, # Run without GUI + device="cuda:0", # GPU device + seed=42, # Random seed + + # Observation configuration + state_keys="robot_joint_pos", # State observation keys (comma-separated) + camera_keys="robot_pov_cam_rgb", # Camera observation keys (comma-separated) + state_dim=54, # Expected state dimension + action_dim=36, # Expected action dimension +
camera_height=512, # Camera image height + camera_width=512, # Camera image width + enable_cameras=True, # Enable camera observations + + # Video recording + video=False, # Enable video recording + video_length=100, # Frames per video + video_interval=200, # Steps between recordings + + # Advanced + mimic=False, # Enable mimic mode + teleop_device=None, # Teleoperation device + disable_fabric=False, # Disable fabric optimization + enable_pinocchio=True, # Enable Pinocchio for IK +) +``` + +### Using Environment Hub directly for advanced usage + +Create a file called `test_env_load_arena.py` or [download from the EnvHub](https://huggingface.co/nvidia/isaaclab-arena-envs/blob/main/tests/test_env_load_arena.py): + +```python +import logging +from dataclasses import asdict +from pprint import pformat +import torch +import tqdm +from lerobot.configs import parser +from lerobot.configs.eval import EvalPipelineConfig + + +@parser.wrap() +def main(cfg: EvalPipelineConfig): + """Run zero action rollout for IsaacLab Arena environment.""" + logging.info(pformat(asdict(cfg))) + + from lerobot.envs.factory import make_env + + env_dict = make_env( + cfg.env, + n_envs=cfg.env.num_envs, + trust_remote_code=True, + ) + env = next(iter(env_dict.values()))[0] + env.reset() + for _ in tqdm.tqdm(range(cfg.env.episode_length)): + with torch.inference_mode(): + actions = env.action_space.sample() + obs, rewards, terminated, truncated, info = env.step(actions) + if terminated.any() or truncated.any(): + obs, info = env.reset() + env.close() + + +if __name__ == "__main__": + main() +``` + +Run with: + +```bash +python test_env_load_arena.py \ + --env.environment=g1_locomanip_pnp \ + --env.embodiment=gr1_pink \ + --env.object=cracker_box \ + --env.num_envs=4 \ + --env.enable_cameras=true \ + --env.seed=1000 \ + --env.video=true \ + --env.video_length=10 \ + --env.video_interval=15 \ + --env.headless=false \ + --env.hub_path=nvidia/isaaclab-arena-envs \ + --env.type=isaaclab_arena +``` + +## 
Troubleshooting + +### CUDA out of memory + +Reduce `batch_size` or use a GPU with more VRAM: + +```bash +--eval.batch_size=1 +``` + +### EULA not accepted + +Set environment variables before running: + +```bash +export ACCEPT_EULA=Y +export PRIVACY_CONSENT=Y +``` + +### Video recording not working + +Enable cameras when running headless: + +```bash +--env.video=true --env.enable_cameras=true --env.headless=true +``` + +### Policy output dimension mismatch + +E.g. ensure `action_dim` matches your policy: + +```bash +--env.action_dim=36 +``` + +### libGLU.so.1 Errors during Isaac Sim initialization + +Ensure you have the following dependencies installed, this is likely to happen on headless machines. + +```bash +sudo apt update && sudo apt install -y libglu1-mesa libxt6 +``` + +## See Also + +- [EnvHub Documentation](./envhub.mdx) - General EnvHub usage +- [IsaacLab Arena GitHub](https://github.com/isaac-sim/IsaacLab-Arena) +- [IsaacLab Documentation](https://isaac-sim.github.io/IsaacLab/) + +## LightWheel LW-BenchHub + +[LightWheel AI](https://www.lightwheel.ai) are bringing `Lightwheel-Libero-Tasks` and `Lightwheel-RoboCasa-Tasks` with 268 tasks to the LeRobot ecosystem. +LW-BenchHub collects and generates large-scale datasets via teleoperation that comply with the LeRobot specification, enabling out-of-the-box training and evaluation workflows. +With the unified interface provided by EnvHub, developers can quickly build end-to-end experimental pipelines. + +### Install + +Assuming you followed the [Installation](#installation) steps, you can install LW-BenchHub with: + +```bash +conda install pinocchio -c conda-forge -y +git clone https://github.com/LightwheelAI/lw_benchhub +cd lw_benchhub +pip install -e . +``` + +For more detailed instructions, please refer to the [LW-BenchHub Documentation](https://docs.lightwheel.net/lw_benchhub/usage/Installation). 
+ +### Lightwheel Tasks Dataset + +LW-BenchHub datasets are available on HuggingFace Hub: + +| Dataset | Description | Tasks | Frames | +| :------------------------------------------------------------------------------------------------------------ | :---------------------- | :---- | :----- | +| [Lightwheel-Tasks-X7S](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-X7S) | X7S LIBERO and RoboCasa | 117 | ~10.3M | +| [Lightwheel-Tasks-Double-Piper](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-Double-Piper) | Double-Piper LIBERO | 130 | ~6.0M | +| [Lightwheel-Tasks-G1-Controller](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-G1-Controller) | G1-Controller LIBERO | 62 | ~2.7M | +| [Lightwheel-Tasks-G1-WBC](https://huggingface.co/datasets/LightwheelAI/Lightwheel-Tasks-G1-WBC) | G1-WBC RoboCasa | 32 | ~1.5M | + +For training policies, refer to the [Training Policies](#training-policies) section. + +### Evaluating Policies + +#### Pre-trained Policies + +The following trained policies are available: + +| Policy | Architecture | Task | Layout | Robot | Link | +| :----------------------- | :----------- | :----------------------------- | :--------- | :-------------- | :------------------------------------------------------------------------------------ | +| smolvla-double-piper-pnp | SmolVLA | L90K1PutTheBlackBowlOnThePlate | libero-1-1 | DoublePiper-Abs | [HuggingFace](https://huggingface.co/LightwheelAI/smolvla-double-piper-pnp/tree/main) | + +#### Evaluate SmolVLA + +```bash +lerobot-eval \ + --policy.path=LightwheelAI/smolvla-double-piper-pnp \ + --env.type=isaaclab_arena \ + --rename_map='{"observation.images.left_hand_camera_rgb": "observation.images.left_hand", "observation.images.right_hand_camera_rgb": "observation.images.right_hand", "observation.images.first_person_camera_rgb": "observation.images.first_person"}' \ + --env.hub_path=LightwheelAI/lw_benchhub_env \ + --env.kwargs='{"config_path": 
"configs/envhub/example.yml"}' \ + --trust_remote_code=true \ + --env.state_keys=joint_pos \ + --env.action_dim=12 \ + --env.camera_keys=left_hand_camera_rgb,right_hand_camera_rgb,first_person_camera_rgb \ + --policy.device=cuda \ + --eval.batch_size=10 \ + --eval.n_episodes=100 +``` + +### Environment Configuration + +Evaluation can be quickly launched by modifying the `robot`, `task`, and `layout` settings in the configuration file. + +#### Full Configuration Options + +```yml +# ========================= +# Basic Settings +# ========================= +disable_fabric: false +device: cuda:0 +sensitivity: 1.0 +step_hz: 50 +enable_cameras: true +execute_mode: eval +episode_length_s: 20.0 # Episode length in seconds, increase if episodes timeout during eval + +# ========================= +# Robot Settings +# ========================= +robot: DoublePiper-Abs # Robot type, DoublePiper-Abs, X7S-Abs, G1-Controller or G1-Controller-DecoupledWBC +robot_scale: 1.0 + +# ========================= +# Task & Scene Settings +# ========================= +task: L90K1PutTheBlackBowlOnThePlate # Task name +scene_backend: robocasa +task_backend: robocasa +debug_assets: null +layout: libero-1-1 # Layout and style ID +sources: + - objaverse + - lightwheel + - aigen_objs +object_projects: [] +usd_simplify: false +seed: 42 + +# ========================= +# Object Placement Retry Settings +# ========================= +max_scene_retry: 4 +max_object_placement_retry: 3 + +resample_objects_placement_on_reset: true +resample_robot_placement_on_reset: true + +# ========================= +# Replay Configuration Settings +# ========================= +replay_cfgs: + add_camera_to_observation: true + render_resolution: [640, 480] +``` + +### See Also + +- [LW-BenchHub GitHub](https://github.com/LightwheelAI/LW-BenchHub) +- [LW-BenchHub Documentation](https://docs.lightwheel.net/lw_benchhub/) diff --git a/src/lerobot/configs/eval.py b/src/lerobot/configs/eval.py index 2f085da56..da8bee6b2 100644 
--- a/src/lerobot/configs/eval.py +++ b/src/lerobot/configs/eval.py @@ -38,6 +38,8 @@ class EvalPipelineConfig: seed: int | None = 1000 # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) + # Explicit consent to execute remote code from the Hub (required for hub environments). + trust_remote_code: bool = False def __post_init__(self) -> None: # HACK: We parse again the cli args here to get the pretrained path if there was one. diff --git a/src/lerobot/envs/__init__.py b/src/lerobot/envs/__init__.py index d767b6e8c..183c12325 100644 --- a/src/lerobot/envs/__init__.py +++ b/src/lerobot/envs/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .configs import AlohaEnv, EnvConfig, PushtEnv # noqa: F401 +from .configs import AlohaEnv, EnvConfig, HubEnvConfig, PushtEnv # noqa: F401 diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 4323f3316..112d3a73f 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -13,7 +13,7 @@ # limitations under the License. import abc -from dataclasses import dataclass, field +from dataclasses import dataclass, field, fields from typing import Any import draccus @@ -68,6 +68,22 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): raise NotImplementedError() +@dataclass +class HubEnvConfig(EnvConfig): + """Base class for environments that delegate creation to a hub-hosted make_env. + + Hub environments download and execute remote code from the HF Hub. + The hub_path points to a repository containing an env.py with a make_env function. 
+ """ + + hub_path: str | None = None # required: e.g., "username/repo" or "username/repo@branch:file.py" + + @property + def gym_kwargs(self) -> dict: + # Not used for hub environments - the hub's make_env handles everything + return {} + + @EnvConfig.register_subclass("aloha") @dataclass class AlohaEnv(EnvConfig): @@ -368,3 +384,71 @@ class MetaworldEnv(EnvConfig): "obs_type": self.obs_type, "render_mode": self.render_mode, } + + +@EnvConfig.register_subclass("isaaclab_arena") +@dataclass +class IsaaclabArenaEnv(HubEnvConfig): + hub_path: str = "nvidia/isaaclab-arena-envs" + episode_length: int = 300 + num_envs: int = 1 + embodiment: str | None = "gr1_pink" + object: str | None = "power_drill" + mimic: bool = False + teleop_device: str | None = None + seed: int | None = 42 + device: str | None = "cuda:0" + disable_fabric: bool = False + enable_cameras: bool = False + headless: bool = False + enable_pinocchio: bool = True + environment: str | None = "gr1_microwave" + task: str | None = "Reach out to the microwave and open it." + state_dim: int = 54 + action_dim: int = 36 + camera_height: int = 512 + camera_width: int = 512 + video: bool = False + video_length: int = 100 + video_interval: int = 200 + # Comma-separated keys, e.g., "robot_joint_pos,left_eef_pos" + state_keys: str = "robot_joint_pos" + # Comma-separated keys, e.g., "robot_pov_cam_rgb,front_cam_rgb" + # Set to None or "" for environments without cameras + camera_keys: str | None = None + features: dict[str, PolicyFeature] = field(default_factory=dict) + features_map: dict[str, str] = field(default_factory=dict) + kwargs: dict | None = None + + def __post_init__(self): + if self.kwargs: + # dynamically convert kwargs to fields in the dataclass + # NOTE! 
the new fields will not be seen by the dataclass repr + field_names = {f.name for f in fields(self)} + for key, value in self.kwargs.items(): + if key not in field_names and key != "kwargs": + setattr(self, key, value) + self.kwargs = None + + # Set action feature + self.features[ACTION] = PolicyFeature(type=FeatureType.ACTION, shape=(self.action_dim,)) + self.features_map[ACTION] = ACTION + + # Set state feature + self.features[OBS_STATE] = PolicyFeature(type=FeatureType.STATE, shape=(self.state_dim,)) + self.features_map[OBS_STATE] = OBS_STATE + + # Add camera features for each camera key + if self.enable_cameras and self.camera_keys: + for cam_key in self.camera_keys.split(","): + cam_key = cam_key.strip() + if cam_key: + self.features[cam_key] = PolicyFeature( + type=FeatureType.VISUAL, + shape=(self.camera_height, self.camera_width, 3), + ) + self.features_map[cam_key] = f"{OBS_IMAGES}.{cam_key}" + + @property + def gym_kwargs(self) -> dict: + return {} diff --git a/src/lerobot/envs/factory.py b/src/lerobot/envs/factory.py index b39cfee71..1c59ccb7d 100644 --- a/src/lerobot/envs/factory.py +++ b/src/lerobot/envs/factory.py @@ -20,11 +20,11 @@ import gymnasium as gym from gymnasium.envs.registration import registry as gym_registry from lerobot.configs.policies import PreTrainedConfig -from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv +from lerobot.envs.configs import AlohaEnv, EnvConfig, HubEnvConfig, IsaaclabArenaEnv, LiberoEnv, PushtEnv from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result from lerobot.policies.xvla.configuration_xvla import XVLAConfig from lerobot.processor import ProcessorStep -from lerobot.processor.env_processor import LiberoProcessorStep +from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep from lerobot.processor.pipeline import PolicyProcessorPipeline @@ -73,6 +73,26 @@ def make_env_pre_post_processors( if
isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type: preprocessor_steps.append(LiberoProcessorStep()) + # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep + if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type: + # Parse comma-separated keys (handle None for state-based policies) + if env_cfg.state_keys: + state_keys = tuple(k.strip() for k in env_cfg.state_keys.split(",") if k.strip()) + else: + state_keys = () + if env_cfg.camera_keys: + camera_keys = tuple(k.strip() for k in env_cfg.camera_keys.split(",") if k.strip()) + else: + camera_keys = () + if not state_keys and not camera_keys: + raise ValueError("At least one of state_keys or camera_keys must be specified.") + preprocessor_steps.append( + IsaaclabArenaProcessorStep( + state_keys=state_keys, + camera_keys=camera_keys, + ) + ) + preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps) postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps) @@ -98,7 +118,6 @@ def make_env( hub_cache_dir (str | None): Optional cache path for downloaded hub files. trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub. Default False — must be set to True to import/exec hub `env.py`. 
- Raises: ValueError: if n_envs < 1 ModuleNotFoundError: If the requested env package is not installed @@ -112,19 +131,35 @@ def make_env( """ # if user passed a hub id string (e.g., "username/repo", "username/repo@main:env.py") # simplified: only support hub-provided `make_env` + # TODO: (jadechoghari): deprecate string API and remove this check if isinstance(cfg, str): + hub_path: str | None = cfg + elif isinstance(cfg, HubEnvConfig): + hub_path = cfg.hub_path + else: + hub_path = None + + # If hub_path is set, download and call hub-provided `make_env` + if hub_path: # _download_hub_file will raise the same RuntimeError if trust_remote_code is False - repo_id, file_path, local_file, revision = _download_hub_file(cfg, trust_remote_code, hub_cache_dir) + repo_id, file_path, local_file, revision = _download_hub_file( + hub_path, trust_remote_code, hub_cache_dir + ) # import and surface clear import errors module = _import_hub_module(local_file, repo_id) # call the hub-provided make_env - raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs) + env_cfg = None if isinstance(cfg, str) else cfg + raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg) # normalize the return into {suite: {task_id: vec_env}} return _normalize_hub_result(raw_result) + # At this point, cfg must be an EnvConfig (not a string) since hub_path would have been set otherwise + if isinstance(cfg, str): + raise TypeError("cfg should be an EnvConfig at this point") + if n_envs < 1: raise ValueError("`n_envs` must be at least 1") diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index 8d0f24922..af814c92a 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -46,7 +46,7 @@ def _convert_nested_dict(d): def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]: - # TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding) + # 
TODO(jadechoghari, imstevenpmwork): refactor this to use features from the environment (no hardcoding) """Convert environment observation to LeRobot format observation. Args: observation: Dictionary of observation batches from a Gym vector environment. @@ -98,11 +98,19 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten if "robot_state" in observations: return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"]) + + # Handle IsaacLab Arena format: observations have 'policy' and 'camera_obs' keys + if "policy" in observations: + return_observations[f"{OBS_STR}.policy"] = observations["policy"] + + if "camera_obs" in observations: + return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"] + return return_observations def env_to_policy_features(env_cfg: EnvConfig) -> dict[str, PolicyFeature]: - # TODO(aliberts, rcadene): remove this hardcoding of keys and just use the nested keys as is + # TODO(jadechoghari, imstevenpmwork): remove this hardcoding of keys and just use the nested keys as is # (need to also refactor preprocess_observation and externalize normalization from policies) policy_features = {} for key, ft in env_cfg.features.items(): @@ -302,7 +310,7 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any: return module -def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any: +def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None) -> Any: """ Ensure module exposes make_env and call it. """ @@ -311,7 +319,11 @@ def _call_make_env(module: Any, n_envs: int, use_async_envs: bool) -> Any: f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`." 
) entry_fn = module.make_env - return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs) + # Only pass cfg if it's not None (i.e., when an EnvConfig was provided, not a string hub ID) + if cfg is not None: + return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg) + else: + return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs) def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]: diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index b1872b032..a5210af30 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -18,7 +18,7 @@ from dataclasses import dataclass import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -152,3 +152,78 @@ class LiberoProcessorStep(ObservationProcessorStep): result[mask] = axis * angle.unsqueeze(1) return result + + +@dataclass +@ProcessorStepRegistry.register(name="isaaclab_arena_processor") +class IsaaclabArenaProcessorStep(ObservationProcessorStep): + """ + Processes IsaacLab Arena observations into LeRobot format. + + **State Processing:** + - Extracts state components from obs["policy"] based on `state_keys`. + - Concatenates into a flat vector mapped to "observation.state". + + **Image Processing:** + - Extracts images from obs["camera_obs"] based on `camera_keys`. + - Converts from (B, H, W, C) uint8 to (B, C, H, W) float32 [0, 1]. + - Maps to "observation.images.". + """ + + # Configurable from IsaacLabEnv config / cli args: --env.state_keys="robot_joint_pos,left_eef_pos" + state_keys: tuple[str, ...] + + # Configurable from IsaacLabEnv config / cli args: --env.camera_keys="robot_pov_cam_rgb" + camera_keys: tuple[str, ...] 
+ + def _process_observation(self, observation): + """ + Processes both image and policy state observations from IsaacLab Arena. + """ + processed_obs = {} + + if f"{OBS_STR}.camera_obs" in observation: + camera_obs = observation[f"{OBS_STR}.camera_obs"] + + for cam_name, img in camera_obs.items(): + if cam_name not in self.camera_keys: + continue + + img = img.permute(0, 3, 1, 2).contiguous() + if img.dtype == torch.uint8: + img = img.float() / 255.0 + elif img.dtype != torch.float32: + img = img.float() + + processed_obs[f"{OBS_IMAGES}.{cam_name}"] = img + + # Process policy state -> observation.state + if f"{OBS_STR}.policy" in observation: + policy_obs = observation[f"{OBS_STR}.policy"] + + # Collect state components in order + state_components = [] + for key in self.state_keys: + if key in policy_obs: + component = policy_obs[key] + # Flatten extra dims: (B, N, M) -> (B, N*M) + if component.dim() > 2: + batch_size = component.shape[0] + component = component.view(batch_size, -1) + state_components.append(component) + + if state_components: + state = torch.cat(state_components, dim=-1) + state = state.float() + processed_obs[OBS_STATE] = state + + return processed_obs + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """Not used for policy evaluation.""" + return features + + def observation(self, observation): + return self._process_observation(observation) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index d23b9d083..9c40a1883 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -509,7 +509,12 @@ def eval_main(cfg: EvalPipelineConfig): logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") logging.info("Making environment.") - envs = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + envs = make_env( + 
cfg.env, + n_envs=cfg.eval.batch_size, + use_async_envs=cfg.eval.use_async_envs, + trust_remote_code=cfg.trust_remote_code, + ) logging.info("Making policy.") diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 134228c05..58a83fe69 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -17,7 +17,7 @@ import json import tempfile from collections.abc import Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from pathlib import Path from typing import Any @@ -1884,7 +1884,7 @@ class FeatureContractAddStep(ProcessorStep): """Adds a PolicyFeature""" key: str = "a" - value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,)) + value: PolicyFeature = field(default_factory=lambda: PolicyFeature(type=FeatureType.STATE, shape=(1,))) def __call__(self, transition: EnvTransition) -> EnvTransition: return transition From 17c115c71f58cf4ce875dbb29beeb72f2bc322c0 Mon Sep 17 00:00:00 2001 From: Lior Ben Horin Date: Sun, 4 Jan 2026 17:41:21 +0200 Subject: [PATCH 011/111] IsaacLab Arena Integration documentation update (#2749) * wording * added how to guide to build you own envhub repos * include LW edits * wording * chat fixes * additional * wording * wording * wording * pre commit fixes --- docs/source/envhub_isaaclab_arena.mdx | 60 +++++++++++++++++++++------ 1 file changed, 48 insertions(+), 12 deletions(-) diff --git a/docs/source/envhub_isaaclab_arena.mdx b/docs/source/envhub_isaaclab_arena.mdx index 454a003a0..518def72c 100644 --- a/docs/source/envhub_isaaclab_arena.mdx +++ b/docs/source/envhub_isaaclab_arena.mdx @@ -12,11 +12,11 @@ Train and evaluate imitation learning policies with high-fidelity simulation — [IsaacLab Arena](https://github.com/isaac-sim/IsaacLab-Arena) integrates with NVIDIA IsaacLab to provide: - 🤖 **Humanoid embodiments**: GR1, G1, Galileo with various configurations -- 🎯 **Manipulation & loco-manipulation tasks**: Microwave opening, 
pick-and-place, button pressing +- 🎯 **Manipulation & loco-manipulation tasks**: Door opening, pick-and-place, button pressing, and more - ⚡ **GPU-accelerated rollouts**: Parallel environment execution on NVIDIA GPUs -- 🖼️ **RTX Rendering**: Evaluate vision-based policies with realistic rendering, reflections and refractions. -- 📦 **LeRobot-compatible datasets**: Ready for training with GR00T Nx, PI0, SmolVLA, ACT, Diffusion policies -- 🔄 **EnvHub integration**: Load environments from HuggingFace Hub with one line +- 🖼️ **RTX Rendering**: Evaluate vision-based policies with realistic rendering, reflections and refractions +- 📦 **LeRobot-compatible datasets**: Ready for training with GR00T N1x, PI0, SmolVLA, ACT, and Diffusion policies +- 🔄 **EnvHub integration**: Load environments from HuggingFace EnvHub with one line ## Installation @@ -85,7 +85,7 @@ The following trained policies are available: ```bash pip install -e ".[smolvla]" -pip install numpy==1.26.0 # revert to numpy version is 1.26 +pip install numpy==1.26.0 # revert numpy to version 1.26 ``` ```bash @@ -113,7 +113,7 @@ lerobot-eval \ ```bash pip install -e ".[pi]" -pip install numpy==1.26.0 # revert to numpy version is 1.26 +pip install numpy==1.26.0 # revert numpy to version 1.26 ``` PI0.5 requires disabling torch compile for evaluation: @@ -131,8 +131,8 @@ TORCH_COMPILE_DISABLE=1 TORCHINDUCTOR_DISABLE=1 lerobot-eval \ --env.headless=false \ --env.enable_cameras=true \ --env.video=true \ - --env.video_length 15 \ - --env.video_interval 15 \ + --env.video_length=15 \ + --env.video_interval=15 \ --env.state_keys=robot_joint_pos \ --env.camera_keys=robot_pov_cam_rgb \ --trust_remote_code=True \ @@ -189,7 +189,7 @@ outputs/eval/2026-01-02/14-38-01_isaaclab_arena_smolvla/videos/gr1_microwave_0/e ## Training Policies -To learn more about training policies with LeRobot, please refer to training documentation: +To learn more about training policies with LeRobot, please refer to the training documentation: - 
[SmolVLA](./smolvla) - [Pi0.5](./pi05) @@ -259,7 +259,7 @@ from lerobot.configs.eval import EvalPipelineConfig @parser.wrap() def main(cfg: EvalPipelineConfig): - """Run zero action rollout for IsaacLab Arena environment.""" + """Run random action rollout for IsaacLab Arena environment.""" logging.info(pformat(asdict(cfg))) from lerobot.envs.factory import make_env @@ -302,6 +302,36 @@ python test_env_load_arena.py \ --env.type=isaaclab_arena ``` +## Creating New Environments + +First create a new IsaacLab Arena environment by following the [IsaacLab Arena Documentation](https://isaac-sim.github.io/IsaacLab-Arena/release/0.1.1/index.html). + +Clone our EnvHub repo: + +```bash +git clone https://huggingface.co/nvidia/isaaclab-arena-envs +``` + +Modify the `example_envs.yaml` file based on your new environment. +[Upload](./envhub#step-3-upload-to-the-hub) your modified repo to HuggingFace EnvHub. + + + Your IsaacLab Arena environment code must be locally available during + evaluation. Users can clone your environment repository separately, or you can + bundle the environment code and assets directly in your EnvHub repo. + + +Then, when evaluating, use your new environment: + +```bash +lerobot-eval \ + --env.hub_path=/isaaclab-arena-envs \ + --env.environment= \ + ...other flags... +``` + +We look forward to your contributions! + ## Troubleshooting ### CUDA out of memory @@ -331,7 +361,7 @@ Enable cameras when running headless: ### Policy output dimension mismatch -E.g. ensure `action_dim` matches your policy: +Ensure `action_dim` matches your policy: ```bash --env.action_dim=36 @@ -353,7 +383,7 @@ sudo apt update && sudo apt install -y libglu1-mesa libxt6 ## LightWheel LW-BenchHub -[LightWheel AI](https://www.lightwheel.ai) are bringing `Lightwheel-Libero-Tasks` and `Lightwheel-RoboCasa-Tasks` with 268 tasks to the LeRobot ecosystem. 
+[LightWheel](https://www.lightwheel.ai) is bringing `Lightwheel-Libero-Tasks` and `Lightwheel-RoboCasa-Tasks` with 268 tasks to the LeRobot ecosystem. LW-BenchHub collects and generates large-scale datasets via teleoperation that comply with the LeRobot specification, enabling out-of-the-box training and evaluation workflows. With the unified interface provided by EnvHub, developers can quickly build end-to-end experimental pipelines. @@ -363,7 +393,13 @@ Assuming you followed the [Installation](#installation) steps, you can install L ```bash conda install pinocchio -c conda-forge -y +pip install numpy==1.26.0 # revert numpy to version 1.26 + +sudo apt-get install git-lfs && git lfs install + git clone https://github.com/LightwheelAI/lw_benchhub +git lfs pull # Ensure LFS files (e.g., .usd assets) are downloaded + cd lw_benchhub pip install -e . ``` From 75ab388ecd28537093c7989d49a36b4dbf7921a2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Sun, 4 Jan 2026 17:24:56 +0100 Subject: [PATCH 012/111] chore(readme): update discord invitation link (#2750) --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 02652d1c9..35b28da87 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,7 @@ [![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/) [![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/) [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) +[![Discord](https://img.shields.io/badge/Discord-Join_Us-5865F2?style=flat&logo=discord&logoColor=white)](https://discord.gg/q8Dzzpym3f) @@ -127,7 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu ## Resources - **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API. 
-- **[Discord](https://discord.gg/3gxM6Avj):** Join the `LeRobot` server to discuss with the community. +- **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. From e670ac5daf9b76cf5c18f632b89a8ff85c727ca6 Mon Sep 17 00:00:00 2001 From: githubnemo Date: Mon, 5 Jan 2026 08:51:26 +0100 Subject: [PATCH 013/111] Add basic PEFT support to train script + record module (#1411) * Add basic support for PEFT adapter methods This changes adds support for training policies with much less parameters by applying adapter methods such as LoRA on specific parts of the policies and therefore possibly higher learning rates / batch sizes. To make this as accessible as possible I thought it useful to provide defaults for `target_modules` and `modules_to_save`. Currently only SmolVLA has such defaults but when we agree that this change is useful I will set out to generate more such defaults. While the user can override these settings, they are expected to only change the peft_method, rank and init_type parameters. * Implement loading of PEFT adapters Loading a PEFT adapter is currently done by initializing a policy with default config and then applying the adapter on the resulting model. This has the obvious drawback that any configurations done during training are not applied in the adapted model. Currently the `use_peft` attribute of `PreTrainedConfig` is only set during loading to signal the following code that it has to deal with a PEFT adapter. However we could imagine a scenario where this is already set at training time and stored alongside the adapter. 
* Store policy config alongside PEFT checkpoint Before this change the PEFT-wrapped policy did not save the policy's config alongside the adapter config / weights which prevented us from changing the policy config. Now the policy config is saved both in full training and PEFT training. This change makes loading the PEFT policy adapter much easier as well. * Add default config for ACT * Support targets like `all-linear` * Formatting * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix failing tests * Remove PEFT compatibility changes in config We'll wait for the PEFT release that fixes this for good. * Remove `use_peft` parameter from training script Instead we make the PEFT config optional which has the same effect. * Log adapter config to WandB * Better documentation for CLI arguments * Don't unload & merge the PEFT model This can make things hard when using quantized layers (user expects quantized base layers with unquantized adapters for example, merging defaults to upcast the layers leading to higher memory). * Correct way of identifying when to save config * Add CLI end-to-end tests Currently there don't seem to be any way to test the CLI commands. Since this change mostly happens in those I thought it best to add a way to test these commands end-to-end. More integrated commands like `lerobot-record` need patching but standalone commands like training seem to work fine. * Update default targets Removed ACT since it doesn't make sense to fine-tune ACT without having it pretrained beforehand. SmolVLA and Pi0/0.5 are much more senseful targets. * Clean up loading code - Centralized instantiation of the PEFT wrapper in `make_policy` for inference (e.g. in `lerobot-record`) - Training a PEFT policy also sets `cfg.use_peft` so that all inference code loading the policy can rely on that attribute to identify if PEFT loading is needed - Modified RTC example to also include PEFT policies. 
Mostly because this is an example I'm currently exploring. * Make sure push_to_hub works Since PEFT only wraps `push_to_hub` and not `push_model_to_hub`, the reference to `self` in `policy.push_model_to_hub` is the unwrapped policy which, of course, doesn't know anything about PEFT. To make the upload process aware of PEFT, we pass the unwrapped policy down to `push_model_to_hub` as a kwarg. This is not ideal but I think it is the best way for now. * formatting * Warn when encountering from-scratch-training * Revamp pretrained model loading There were quite a few factors that convinced me that the status quo is able to load pretrained models from the PEFT adapter config but in fact that didn't work. This commit fixes the following things: - policies wrapped in PEFT will now have a `name_or_path` attribute containing the name or path of the pretrained model we're fine-tuning - we further assume that SmolVLA without `pretrained_path` and `load_vlm_weights==False` must be an user-side error - we assume that using PEFT on from-scratch-policies must be an user-side-error * Make it possible to unset policy features This is necessary to train pre-trained policies on new datasets so that the features are inferred from the new dataset and not from the pretrained policy. * Use correct loading for PEFT in RTC example * Make it possible to use PeftModels in eval * Add test checking that PEFT actually reduces params * Adapt state/action projections instead of full-finetuning There doesn't seem to be a benefit to fully fine-tune these layers over just adapting them, so we do that instead. * Disallow PEFT training on non-pretrained policies At first I thought it would make sense to have this feature in case you want to fine-tune a pre-trained section but in the end it makes more trouble than it's worth. It's still possible to allow this in the future when a concrete need arises. 
* Add basic documentation * Formatting * Add peft as extra dependency, mark tests Fast tests currently fail because of the missing dependency. * Fix pre-commit issues * Add walx <> peft conflict for uv * Exclude peft from pi install for now --------- Co-authored-by: nemo Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/_toctree.yml | 2 + docs/source/peft_training.mdx | 62 +++++++ examples/rtc/eval_with_real_robot.py | 13 +- pyproject.toml | 12 +- src/lerobot/configs/default.py | 28 ++++ src/lerobot/configs/policies.py | 16 +- src/lerobot/configs/train.py | 3 +- src/lerobot/policies/factory.py | 31 +++- src/lerobot/policies/pretrained.py | 10 +- src/lerobot/rl/wandb_utils.py | 27 ++- src/lerobot/scripts/lerobot_eval.py | 9 +- src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_train.py | 96 ++++++++++- src/lerobot/utils/train_utils.py | 4 + tests/test_cli.py | 229 ++++++++++++++++++++++++++ tests/utils/test_train_utils.py | 14 ++ 16 files changed, 548 insertions(+), 10 deletions(-) create mode 100644 docs/source/peft_training.mdx create mode 100644 tests/test_cli.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index b59d15734..381e95dc4 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -19,6 +19,8 @@ title: Train RL in Simulation - local: multi_gpu_training title: Multi GPU training + - local: peft_training + title: Training with PEFT (e.g., LoRA) title: "Tutorials" - sections: - local: lerobot-dataset-v3 diff --git a/docs/source/peft_training.mdx b/docs/source/peft_training.mdx new file mode 100644 index 000000000..dd0b10075 --- /dev/null +++ b/docs/source/peft_training.mdx @@ -0,0 +1,62 @@ +# Parameter efficient fine-tuning with 🤗 PEFT + +[🤗 PEFT](https://github.com/huggingface/peft) (Parameter-Efficient Fine-Tuning) is a library for efficiently adapting +large pretrained models 
such as pre-trained policies (e.g., SmolVLA, π₀, ...) to new tasks without training all +of the model's parameters while yielding comparable performance. + +Install the `lerobot[peft]` optional package to enable PEFT support. + +To read about all the possible methods of adaption, please refer to the [🤗 PEFT docs](https://huggingface.co/docs/peft/index). + +## Training SmolVLA + +In this section we'll show you how to train a pre-trained SmolVLA policy with PEFT on the libero dataset. +For brevity we're only training on the `libero_spatial` subset. We will use `lerobot/smolvla_base` as the model +to parameter efficiently fine-tune: + +``` +lerobot-train \ + --policy.path=lerobot/smolvla_base \ + --policy.repo_id=your_hub_name/my_libero_smolvla \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --policy.output_features=null \ + --policy.input_features=null \ + --policy.optimizer_lr=1e-3 \ + --policy.scheduler_decay_lr=1e-4 \ + --env.type=libero \ + --env.task=libero_spatial \ + --steps=100000 \ + --batch_size=32 \ + --peft.method_type=LORA \ + --peft.r=64 +``` + +Note the `--peft.method_type` parameter that let's you select which PEFT method to use. Here we use +[LoRA](https://huggingface.co/docs/peft/main/en/package_reference/lora) (Low-Rank Adapter) which is probably the most +popular fine-tuning method to date. Low-rank adaption means that we only fine-tune a matrix with comparably low rank +instead of the full weight matrix. This rank can be specified using the `--peft.r` parameter. The higher the rank +the closer you get to full fine-tuning + +There are more complex methods that have more parameters. These are not yet supported, feel free to raise an issue +if you want to see a specific PEFT method supported. + +By default, PEFT will target the `q_proj` and `v_proj` layers of the LM expert in SmolVLA. It will also target the +state and action projection matrices as they are most likely task-dependent. 
If you need to target different layers +you can use `--peft.target_modules` to specify which layers to target. You can refer to the respective PEFT method's +documentation to see what inputs are supported (e.g., [LoRA's target_modules documentation](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.target_modules)). +Usually a list of suffixes or a regex is supported. For example, to target the MLPs of the `lm_expert` instead of +the `q` and `v` projections, use: + +``` +--peft.target_modules='(model\.vlm_with_expert\.lm_expert\..*\.(down|gate|up)_proj|.*\.(state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out))' +``` + +In case you need to fully fine-tune a layer instead of just adapting it, you can supply a list of layer suffixes +to the `--peft.full_training_modules` parameter: + +``` +--peft.full_training_modules=["state_proj"] +``` + +The learning rate and the scheduled target learning rate can usually be scaled by a factor of 10 compared to the +learning rate used for full fine-tuning (e.g., 1e-4 normal, so 1e-3 using LoRA).
diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 6f051485a..5f44649da 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -455,7 +455,18 @@ def demo_cli(cfg: RTCDemoConfig): if cfg.policy.type == "pi05" or cfg.policy.type == "pi0": config.compile_model = cfg.use_torch_compile - policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) + if config.use_peft: + from peft import PeftConfig, PeftModel + + peft_pretrained_path = cfg.policy.pretrained_path + peft_config = PeftConfig.from_pretrained(peft_pretrained_path) + + policy = policy_class.from_pretrained( + pretrained_name_or_path=peft_config.base_model_name_or_path, config=config + ) + policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config) + else: + policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config) # Turn on RTC policy.config.rtc_config = cfg.rtc diff --git a/pyproject.toml b/pyproject.toml index 33586c5e4..980844149 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,7 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpci # Features async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] +peft = ["lerobot[transformers-dep]", "peft>=0.18.0"] # Development dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1", "mypy>=1.19.1"] @@ -182,7 +183,8 @@ all = [ "lerobot[phone]", "lerobot[libero]", "lerobot[metaworld]", - "lerobot[sarm]" + "lerobot[sarm]", + "lerobot[peft]", ] [project.scripts] @@ -417,6 +419,10 @@ conflicts = [ { extra = "wallx" }, { extra = "libero" }, ], + [ + { extra = "wallx" }, + { extra = "peft" }, + ], [ { extra = "wallx" }, { extra = "all" }, @@ -450,6 +456,10 @@ conflicts = [ { extra = "pi" }, { extra = "libero" }, ], + [ + { extra = "pi" }, + { extra = "peft" }, + ], [ { extra = "pi" }, { extra = "all" }, diff --git 
a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 630d63f1b..f613b5251 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -67,3 +67,31 @@ class EvalConfig: f"to increase the number of episodes to match the batch size (e.g. `eval.n_episodes={self.batch_size}`), " f"or lower the batch size (e.g. `eval.batch_size={self.n_episodes}`)." ) + + +@dataclass +class PeftConfig: + # PEFT offers many fine-tuning methods, layer adapters being the most common and currently also the most + # effective methods so we'll focus on those in this high-level config interface. + + # Either a string (module name suffix or 'all-linear'), a list of module name suffixes or a regular expression + # describing module names to target with the configured PEFT method. Some policies have a default value for this + # so that you don't *have* to choose which layers to adapt but it might still be worthwhile depending on your case. + target_modules: list[str] | str | None = None + + # Names/suffixes of modules to fully fine-tune and store alongside adapter weights. Useful for layers that are + # not part of a pre-trained model (e.g., action state projections). Depending on the policy this defaults to layers + # that are newly created in pre-trained policies. If you're fine-tuning an already trained policy you might want + # to set this to `[]`. Corresponds to PEFT's `modules_to_save`. + full_training_modules: list[str] | None = None + + # The PEFT (adapter) method to apply to the policy. Needs to be a valid PEFT type. + method_type: str = "LORA" + + # Adapter initialization method. Look at the specific PEFT adapter documentation for defaults. + init_type: str | None = None + + # We expect that all PEFT adapters are in some way doing rank-decomposition therefore this parameter specifies + # the rank used for the adapter. In general a higher rank means more trainable parameters and closer to full + # fine-tuning. 
+ r: int = 16 diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 0ecfa169b..7f326b70b 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -55,14 +55,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno n_obs_steps: int = 1 - input_features: dict[str, PolicyFeature] = field(default_factory=dict) - output_features: dict[str, PolicyFeature] = field(default_factory=dict) + # `input_features` can be set to None/null in order to infer those values from the dataset. + input_features: dict[str, PolicyFeature] | None = field(default_factory=dict) + output_features: dict[str, PolicyFeature] | None = field(default_factory=dict) device: str | None = None # e.g. "cuda", "cuda:0", "cpu", or "mps" # `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP, # automatic gradient scaling is used. use_amp: bool = False + # Whether the policy employed PEFT for training. 
+ use_peft: bool = False + push_to_hub: bool = True # type: ignore[assignment] # TODO: use a different name to avoid override repo_id: str | None = None @@ -125,6 +129,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno @property def robot_state_feature(self) -> PolicyFeature | None: + if not self.input_features: + return None for ft_name, ft in self.input_features.items(): if ft.type is FeatureType.STATE and ft_name == OBS_STATE: return ft @@ -132,6 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno @property def env_state_feature(self) -> PolicyFeature | None: + if not self.input_features: + return None for _, ft in self.input_features.items(): if ft.type is FeatureType.ENV: return ft @@ -139,10 +147,14 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno @property def image_features(self) -> dict[str, PolicyFeature]: + if not self.input_features: + return {} return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL} @property def action_feature(self) -> PolicyFeature | None: + if not self.output_features: + return None for ft_name, ft in self.output_features.items(): if ft.type is FeatureType.ACTION and ft_name == ACTION: return ft diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index cee9dfdf9..7a5eee77d 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -24,7 +24,7 @@ from huggingface_hub.errors import HfHubHTTPError from lerobot import envs from lerobot.configs import parser -from lerobot.configs.default import DatasetConfig, EvalConfig, WandBConfig +from lerobot.configs.default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig from lerobot.configs.policies import PreTrainedConfig from lerobot.optim import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig @@ -65,6 +65,7 @@ class TrainPipelineConfig(HubMixin): scheduler: LRSchedulerConfig | None 
= None eval: EvalConfig = field(default_factory=EvalConfig) wandb: WandBConfig = field(default_factory=WandBConfig) + peft: PeftConfig | None = None # RA-BC (Reward-Aligned Behavior Cloning) parameters use_rabc: bool = False # Enable reward-weighted training diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 3e24656fc..8c414f235 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -471,11 +471,40 @@ def make_policy( if ds_meta is not None: kwargs["dataset_meta"] = ds_meta - if cfg.pretrained_path: + if not cfg.pretrained_path and cfg.use_peft: + raise ValueError( + "Instantiating a policy with `use_peft=True` without a checkpoint is not supported since that requires " + "the PEFT config parameters to be set. For training with PEFT, see `lerobot_train.py` on how to do that." + ) + + if cfg.pretrained_path and not cfg.use_peft: # Load a pretrained policy and override the config if needed (for example, if there are inference-time # hyperparameters that we want to vary). kwargs["pretrained_name_or_path"] = cfg.pretrained_path policy = policy_cls.from_pretrained(**kwargs) + elif cfg.pretrained_path and cfg.use_peft: + # Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo + # of the adapter and the adapter's config contains the path to the base policy. So we need the + # adapter config first, then load the correct policy and then apply PEFT. + from peft import PeftConfig, PeftModel + + logging.info("Loading policy's PEFT adapter.") + + peft_pretrained_path = cfg.pretrained_path + peft_config = PeftConfig.from_pretrained(peft_pretrained_path) + + kwargs["pretrained_name_or_path"] = peft_config.base_model_name_or_path + if not kwargs["pretrained_name_or_path"]: + # This means that there's a bug or we trained a policy from scratch using PEFT. + # It is more likely that this is a bug so we'll raise an error. 
+ raise ValueError( + "No pretrained model name found in adapter config. Can't instantiate the pre-trained policy on which " + "the adapter was trained." + ) + + policy = policy_cls.from_pretrained(**kwargs) + policy = PeftModel.from_pretrained(policy, peft_pretrained_path, config=peft_config) + else: # Make a fresh policy. policy = policy_cls(**kwargs) diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index 3f5d89ec5..a1499d077 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -206,6 +206,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): def push_model_to_hub( self, cfg: TrainPipelineConfig, + peft_model=None, ): api = HfApi() repo_id = api.create_repo( @@ -216,7 +217,14 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): with TemporaryDirectory(ignore_cleanup_errors=True) as tmp: saved_path = Path(tmp) / repo_id - self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors + if peft_model is not None: + # Since PEFT just forwards calls to `push_model_to_hub`, `self` is not the PeftModel wrapper + # but the actual policy which is why we need the PEFT model passed to us to save the adapter. + # That also means that we need to store the policy config ourselves since PEFT can't. 
+ peft_model.save_pretrained(saved_path) + self.config.save_pretrained(saved_path) + else: + self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors card = self.generate_model_card( cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index 1537b3783..7b7f8a57b 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -112,7 +112,32 @@ class WandBLogger: artifact_name = f"{self._group}-{step_id}" artifact_name = get_safe_wandb_artifact_name(artifact_name) artifact = self._wandb.Artifact(artifact_name, type="model") - artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) + pretrained_model_dir = checkpoint_dir / PRETRAINED_MODEL_DIR + + # Check if this is a PEFT model (has adapter files instead of model.safetensors) + adapter_model_file = pretrained_model_dir / "adapter_model.safetensors" + standard_model_file = pretrained_model_dir / SAFETENSORS_SINGLE_FILE + + if adapter_model_file.exists(): + # PEFT model: add adapter files and configs + artifact.add_file(adapter_model_file) + adapter_config_file = pretrained_model_dir / "adapter_config.json" + if adapter_config_file.exists(): + artifact.add_file(adapter_config_file) + # Also add the policy config which is needed for loading + config_file = pretrained_model_dir / "config.json" + if config_file.exists(): + artifact.add_file(config_file) + elif standard_model_file.exists(): + # Standard model: add the single safetensors file + artifact.add_file(standard_model_file) + else: + logging.warning( + f"No {SAFETENSORS_SINGLE_FILE} or adapter_model.safetensors found in {pretrained_model_dir}. " + "Skipping model artifact upload to WandB." 
+ ) + return + self._wandb.log_artifact(artifact) def log_dict( diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 9c40a1883..7b45e88e1 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -278,9 +278,16 @@ def eval_policy( raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") if not isinstance(policy, PreTrainedPolicy): - raise ValueError( + exc = ValueError( f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided." ) + try: + from peft import PeftModel + + if not isinstance(policy, PeftModel): + raise exc + except ImportError: + raise exc from None start = time.time() policy.eval() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 948e92bb8..eafd32ace 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -193,8 +193,10 @@ class RecordConfig: def __post_init__(self): # HACK: We parse again the cli args here to get the pretrained path if there was one. policy_path = parser.get_path_arg("policy") + if policy_path: cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) self.policy.pretrained_path = policy_path diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 6cf733442..286c69906 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import dataclasses import logging import time from contextlib import nullcontext @@ -147,6 +148,92 @@ def update_policy( return train_metrics, output_dict +def get_default_peft_configuration(policy_type): + """Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint.""" + + common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + + if policy_type == "smolvla": + return { + "target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))", + "modules_to_save": [], + } + elif policy_type in ("pi0", "pi05"): + return { + "target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))", + "modules_to_save": [], + } + + return {"modules_to_save": None} + + +def wrap_policy_in_peft_model(cfg, policy): + from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model + + # Disable all gradients because we'll only train the parameters selected by the PEFT method. + # Layers that should receive gradients anyway need to be listed in `modules_to_save`. + for p in policy.parameters(): + p.requires_grad_(False) + + if not cfg.policy.pretrained_path: + raise ValueError( + "Training from scratch using PEFT. This is unlikely to yield good results. " + "Supply a `policy.path` to fine-tune an existing model." + ) + + if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights: + logging.warning( + "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set " + "`load_vlm_weights=True` to fine-tune the existing policy." 
+ ) + + peft_config_policy = get_default_peft_configuration(cfg.policy.type) + peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {} + peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT + peft_method_type = PeftType[peft_config_cli["method_type"].upper()] + peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type] + + # Handle specific CLI overrides + for key in ["target_modules", "modules_to_save", "r"]: + if peft_config_cli[key] is not None: + peft_config_policy[key] = peft_config_cli[key] + + if "target_modules" not in peft_config_policy: + raise ValueError( + f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually." + ) + + # Init method depends on the used PEFT method, your specific PEFT method + # might not be considered here, in that case an error is raised. + if peft_config_cli["init_type"] is not None: + if peft_method_type == "LORA": + peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"] + elif peft_method_type == "MISS": + peft_config_policy["init_weights"] = peft_config_cli["init_type"] + else: + raise ValueError( + f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}." + ) + + # PEFT uses this attribute to set adapter_config.base_model_name_or_path which we use for loading the + # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the + # adapter, not the base model. + if policy.config.pretrained_path: + policy.name_or_path = str(policy.config.pretrained_path) + + # Finally wrap the policy in a PEFT model + policy = get_peft_model( + policy, + peft_config_cls(**peft_config_policy), + ) + + # Make sure that the config is tagged as using PEFT so that the loading code can take the + # appropriate steps to use the adapter weights and the PEFT config instead of the full model weights. 
+ policy.config.use_peft = True + + return policy + + @parser.wrap() def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): """ @@ -230,6 +317,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): rename_map=cfg.rename_map, ) + if cfg.peft is not None: + logging.info("Using PEFT! Wrapping model.") + policy = wrap_policy_in_peft_model(cfg, policy) + # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone() @@ -502,7 +593,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if cfg.policy.push_to_hub: unwrapped_policy = accelerator.unwrap_model(policy) - unwrapped_policy.push_model_to_hub(cfg) + if cfg.policy.use_peft: + unwrapped_policy.push_model_to_hub(cfg, peft_model=unwrapped_policy) + else: + unwrapped_policy.push_model_to_hub(cfg) preprocessor.push_to_hub(cfg.policy.repo_id) postprocessor.push_to_hub(cfg.policy.repo_id) diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index 3ebe31971..d8481f4b9 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -99,6 +99,10 @@ def save_checkpoint( pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir) + if cfg.peft is not None: + # When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the + # policy config which we need for loading the model. In this case we'll write it ourselves. 
+ policy.config.save_pretrained(pretrained_dir) if preprocessor is not None: preprocessor.save_pretrained(pretrained_dir) if postprocessor is not None: diff --git a/tests/test_cli.py b/tests/test_cli.py new file mode 100644 index 000000000..ddaae0c9b --- /dev/null +++ b/tests/test_cli.py @@ -0,0 +1,229 @@ +import importlib +import os +from unittest.mock import MagicMock, patch + +import pytest +from safetensors.torch import load_file + +from .utils import require_package + + +def run_command(cmd, module, args): + module = importlib.import_module(f"lerobot.scripts.{module}") + with patch("sys.argv", [cmd] + args): + module.main() + + +def lerobot_train(args): + return run_command(cmd="lerobot-train", module="lerobot_train", args=args) + + +def lerobot_record(args): + return run_command(cmd="lerobot-record", module="lerobot_record", args=args) + + +def resolve_model_id_for_peft_training(policy_type): + """PEFT training needs pretrained models, this finds the pretrained model of a policy type for PEFT training.""" + if policy_type == "smolvla": + return "lerobot/smolvla_base" + + raise ValueError(f"No pretrained model known for {policy_type}. 
PEFT training will not work.") + + +@pytest.mark.parametrize("policy_type", ["smolvla"]) +@require_package("peft") +def test_peft_training_push_to_hub_works(policy_type, tmp_path): + """Ensure that push to hub stores only the PEFT adapter, not the full model weights.""" + output_dir = tmp_path / f"output_{policy_type}" + upload_folder_contents = set() + + model_id = resolve_model_id_for_peft_training(policy_type) + + def mock_upload_folder(*args, **kwargs): + folder_path = kwargs["folder_path"] + # we include more than is actually uploaded since we ignore {allow,ignore}_patterns of upload_folders() + upload_folder_contents.update(os.listdir(folder_path)) + return MagicMock() + + with ( + patch("huggingface_hub.HfApi.create_repo"), + patch("huggingface_hub.HfApi.upload_folder", mock_upload_folder), + ): + lerobot_train( + [ + f"--policy.path={model_id}", + "--policy.push_to_hub=true", + "--policy.repo_id=foo/bar", + "--policy.input_features=null", + "--policy.output_features=null", + "--peft.method=LORA", + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0, 1]", + "--steps=1", + f"--output_dir={output_dir}", + ] + ) + + assert "adapter_model.safetensors" in upload_folder_contents + assert "config.json" in upload_folder_contents + assert "adapter_config.json" in upload_folder_contents + + +@pytest.mark.parametrize("policy_type", ["smolvla"]) +@require_package("peft") +def test_peft_training_works(policy_type, tmp_path): + """Check whether the standard case of fine-tuning a (partially) pre-trained policy with PEFT works.""" + output_dir = tmp_path / f"output_{policy_type}" + model_id = resolve_model_id_for_peft_training(policy_type) + + lerobot_train( + [ + f"--policy.path={model_id}", + "--policy.push_to_hub=false", + "--policy.input_features=null", + "--policy.output_features=null", + "--peft.method=LORA", + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0, 1]", + "--steps=1", + f"--output_dir={output_dir}", + ] + ) + + policy_dir = output_dir / 
"checkpoints" / "last" / "pretrained_model" + + for file in ["adapter_config.json", "adapter_model.safetensors", "config.json"]: + assert (policy_dir / file).exists() + + # This is the default case where we train a pre-trained policy from scratch with new data. + # We assume that we target policy-specific modules but fully fine-tune action and state projections + # so these must be part of the trained state dict. + state_dict = load_file(policy_dir / "adapter_model.safetensors") + + adapted_keys = [ + "state_proj", + "action_in_proj", + "action_out_proj", + "action_time_mlp_in", + "action_time_mlp_out", + ] + + found_keys = [ + module_key + for module_key in adapted_keys + for state_dict_key in state_dict + if f".{module_key}." in state_dict_key + ] + + assert set(found_keys) == set(adapted_keys) + + +@pytest.mark.parametrize("policy_type", ["smolvla"]) +@require_package("peft") +def test_peft_training_params_are_fewer(policy_type, tmp_path): + """Check that a PEFT-wrapped policy has fewer trainable parameters than total parameters during training.""" + output_dir = tmp_path / f"output_{policy_type}" + model_id = resolve_model_id_for_peft_training(policy_type) + + def dummy_update_policy( + train_metrics, policy, batch, optimizer, grad_clip_norm: float, accelerator, **kwargs + ): + params_total = sum(p.numel() for p in policy.parameters()) + params_trainable = sum(p.numel() for p in policy.parameters() if p.requires_grad) + + assert params_total > params_trainable + + return train_metrics, {} + + with patch("lerobot.scripts.lerobot_train.update_policy", dummy_update_policy): + lerobot_train( + [ + f"--policy.path={model_id}", + "--policy.push_to_hub=false", + "--policy.input_features=null", + "--policy.output_features=null", + "--peft.method=LORA", + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0, 1]", + "--steps=1", + f"--output_dir={output_dir}", + ] + ) + + +class DummyRobot: + name = "dummy" + cameras = [] + action_features = {"foo": 1.0, "bar": 2.0} + 
observation_features = {"obs1": 1.0, "obs2": 2.0} + is_connected = True + + def connect(self, *args): + pass + + def disconnect(self): + pass + + +def dummy_make_robot_from_config(*args, **kwargs): + return DummyRobot() + + +@pytest.mark.parametrize("policy_type", ["smolvla"]) +@require_package("peft") +def test_peft_record_loads_policy(policy_type, tmp_path): + """Train a policy with PEFT and attempt to load it with `lerobot-record`.""" + from peft import PeftModel + + output_dir = tmp_path / f"output_{policy_type}" + model_id = resolve_model_id_for_peft_training(policy_type) + + lerobot_train( + [ + f"--policy.path={model_id}", + "--policy.push_to_hub=false", + "--policy.input_features=null", + "--policy.output_features=null", + "--peft.method=LORA", + "--dataset.repo_id=lerobot/pusht", + "--dataset.episodes=[0, 1]", + "--steps=1", + f"--output_dir={output_dir}", + ] + ) + + policy_dir = output_dir / "checkpoints" / "last" / "pretrained_model" + dataset_dir = tmp_path / "eval_pusht" + single_task = "move the table" + loaded_policy = None + + def dummy_record_loop(*args, **kwargs): + nonlocal loaded_policy + + if "dataset" not in kwargs: + return + + dataset = kwargs["dataset"] + dataset.add_frame({"task": single_task}) + loaded_policy = kwargs["policy"] + + with ( + patch("lerobot.robots.make_robot_from_config", dummy_make_robot_from_config), + # disable record loop since we're only interested in successful loading of the policy. 
+ patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop), + # disable speech output + patch("lerobot.utils.utils.say"), + ): + lerobot_record( + [ + f"--policy.path={policy_dir}", + "--robot.type=so101_follower", + "--robot.port=/dev/null", + "--dataset.repo_id=lerobot/eval_pusht", + f'--dataset.single_task="{single_task}"', + f"--dataset.root={dataset_dir}", + "--dataset.push_to_hub=false", + ] + ) + + assert isinstance(loaded_policy, PeftModel) diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index 892503e97..4791caf58 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -82,6 +82,20 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer): mock_save_training_state.assert_called_once() +@patch("lerobot.utils.train_utils.save_training_state") +def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer): + policy = Mock() + policy.config = Mock() + policy.config.save_pretrained = Mock() + cfg = Mock() + cfg.use_peft = True + save_checkpoint(tmp_path, 10, cfg, policy, optimizer) + policy.save_pretrained.assert_called_once() + cfg.save_pretrained.assert_called_once() + policy.config.save_pretrained.assert_called_once() + mock_save_training_state.assert_called_once() + + def test_save_training_state(tmp_path, optimizer, scheduler): save_training_state(tmp_path, 10, optimizer, scheduler) assert (tmp_path / TRAINING_STATE_DIR).is_dir() From 6106a8136c62fad75905f1eefc087f2fdaa6b10c Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Mon, 5 Jan 2026 12:13:42 +0100 Subject: [PATCH 014/111] Fix invalid syntax (#2752) * fix invalid syntax * also skip for torchdiffeq * fix patch for gpu tests --- tests/policies/wall_x/test_wallx.py | 5 +++-- tests/test_cli.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/policies/wall_x/test_wallx.py b/tests/policies/wall_x/test_wallx.py index 837907041..e5f124123 
100644 --- a/tests/policies/wall_x/test_wallx.py +++ b/tests/policies/wall_x/test_wallx.py @@ -21,9 +21,10 @@ import os import pytest import torch -# Skip if openpi or transformers is not available +# Skip if required dependencies are not available pytest.importorskip("peft") -pytest.importorskip("transformers==4.49.0") +pytest.importorskip("transformers") +pytest.importorskip("torchdiffeq") # Skip this entire module in CI pytestmark = pytest.mark.skipif( diff --git a/tests/test_cli.py b/tests/test_cli.py index ddaae0c9b..a3472b48e 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -208,7 +208,7 @@ def test_peft_record_loads_policy(policy_type, tmp_path): loaded_policy = kwargs["policy"] with ( - patch("lerobot.robots.make_robot_from_config", dummy_make_robot_from_config), + patch("lerobot.scripts.lerobot_record.make_robot_from_config", dummy_make_robot_from_config), # disable record loop since we're only interested in successful loading of the policy. patch("lerobot.scripts.lerobot_record.record_loop", dummy_record_loop), # disable speech output From 603d44434f432c79765c6e18c290fccac1bd7b3e Mon Sep 17 00:00:00 2001 From: Tong Wu <54630004+wut19@users.noreply.github.com> Date: Tue, 6 Jan 2026 22:13:35 +0800 Subject: [PATCH 015/111] fix a bug for kwargs in wallx (#2714) * support wallx * fix bugs in flow * incorporate wallx model into lerobot * update the policy methods * reduce to least config and params & pass lerobot basic test * fixed dtype bugs * add wallx dependencies * update * remove flash-attn requirement && fix bug in inference and fast mode * fix bug for inference * add some small modifications * fix pre-commit errors * remove lerobot[wallx] * fix ci * fix precommit issues * fix: exclude wallx extra properly in CI workflows * fix: add uv conflicts for wallx transformers version * fix: peft test import * pre-commit * only export WallXConfig from wall_x package to avoid peft import in CI * remove torch dep * precommit * add import * update doc files 
* fix minor errors * fix a bug for kwargs * fix precommit issue --------- Signed-off-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: vincentchen Co-authored-by: Geoffrey19 Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Pepijn Co-authored-by: geoffrey --- docs/source/policy_walloss_README.md | 18 ++++++++++++++---- src/lerobot/policies/wall_x/modeling_wall_x.py | 2 +- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/docs/source/policy_walloss_README.md b/docs/source/policy_walloss_README.md index 78548bd8d..93c0ad392 100644 --- a/docs/source/policy_walloss_README.md +++ b/docs/source/policy_walloss_README.md @@ -1,20 +1,30 @@ # WALL-OSS -This repository contains the Hugging Face port of **WALL-OSS**, a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. +This repository contains the Hugging Face port of [**WALL-OSS**](https://x2robot.com/en/research/68bc2cde8497d7f238dde690), a Vision-Language-Action model for cross-embodiment robotic control based on Qwen2.5-VL with flow matching/FAST action prediction. 
--- ## Model Overview | Feature | Description | -| ------------------ | ----------------------------------------------------- | --- | +| ------------------ | ----------------------------------------------------- | | Base Model | Qwen2.5-VL (Vision-Language Model) | | Action Prediction | Flow Matching (diffusion) or FAST (discrete tokens) | -| Architecture | Mixture of Experts (MoE) with action-specific routing | | +| Architecture | Mixture of Experts (MoE) with action-specific routing | | Multi-Modal Inputs | Vision (images/videos), Language, Proprioception | --- +## Additional Resources + +Paper: https://arxiv.org/pdf/2509.11766 + +Official Repository: https://github.com/X-Square-Robot/wall-x + +Hugging Face: https://huggingface.co/x-square-robot + +--- + ## Citation If you use this work, please cite: @@ -32,4 +42,4 @@ If you use this work, please cite: ## License -This port follows the **Apache 2.0 License**. +This model follows the **Apache 2.0 License**, consistent with the original [WallX repository](https://github.com/X-Square-Robot/wall-x). 
diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index c401c8d60..94ee7897e 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -1697,7 +1697,7 @@ class WallXPolicy(PreTrainedPolicy): config_class = WallXConfig name = "wall_x" - def __init__(self, config: WallXConfig): + def __init__(self, config: WallXConfig, **kwargs): super().__init__(config) config.validate_features() self.config = config From 963a3482fac60c5af3f26188507d0055c9a0362a Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 6 Jan 2026 18:17:29 +0100 Subject: [PATCH 016/111] typo LW (#2758) --- docs/source/envhub_isaaclab_arena.mdx | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/envhub_isaaclab_arena.mdx b/docs/source/envhub_isaaclab_arena.mdx index 518def72c..828d51bad 100644 --- a/docs/source/envhub_isaaclab_arena.mdx +++ b/docs/source/envhub_isaaclab_arena.mdx @@ -381,9 +381,9 @@ sudo apt update && sudo apt install -y libglu1-mesa libxt6 - [IsaacLab Arena GitHub](https://github.com/isaac-sim/IsaacLab-Arena) - [IsaacLab Documentation](https://isaac-sim.github.io/IsaacLab/) -## LightWheel LW-BenchHub +## Lightwheel LW-BenchHub -[LightWheel](https://www.lightwheel.ai) is bringing `Lightwheel-Libero-Tasks` and `Lightwheel-RoboCasa-Tasks` with 268 tasks to the LeRobot ecosystem. +[Lightwheel](https://www.lightwheel.ai) is bringing `Lightwheel-Libero-Tasks` and `Lightwheel-RoboCasa-Tasks` with 268 tasks to the LeRobot ecosystem. LW-BenchHub collects and generates large-scale datasets via teleoperation that comply with the LeRobot specification, enabling out-of-the-box training and evaluation workflows. With the unified interface provided by EnvHub, developers can quickly build end-to-end experimental pipelines. 
From e2957d778397fd70f7200fb69763238dc1aaa132 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 6 Jan 2026 20:09:43 +0100 Subject: [PATCH 017/111] fix: precise_sleep is never called with negative value (#2757) --- docs/source/envhub_leisaac.mdx | 2 +- docs/source/il_robots.mdx | 2 +- examples/backward_compatibility/replay.py | 2 +- examples/phone_to_so100/replay.py | 2 +- examples/so100_to_so100_EE/replay.py | 2 +- src/lerobot/rl/actor.py | 2 +- src/lerobot/rl/gym_manipulator.py | 6 +++--- src/lerobot/scripts/lerobot_record.py | 2 +- src/lerobot/scripts/lerobot_replay.py | 2 +- src/lerobot/scripts/lerobot_teleoperate.py | 2 +- 10 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/envhub_leisaac.mdx b/docs/source/envhub_leisaac.mdx index ff848d415..5cf4a0e45 100644 --- a/docs/source/envhub_leisaac.mdx +++ b/docs/source/envhub_leisaac.mdx @@ -196,7 +196,7 @@ def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int): obs, info = env.reset() dt_s = time.perf_counter() - loop_start - precise_sleep(1 / fps - dt_s) + precise_sleep(max(1 / fps - dt_s, 0.0)) loop_s = time.perf_counter() - loop_start print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 0bc1ca681..eb779d2b1 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -432,7 +432,7 @@ for idx in range(dataset.num_frames): } robot.send_action(action) - precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0)) + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) robot.disconnect() ``` diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index ed52a24c9..85f3ecef7 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -97,7 +97,7 @@ def replay(cfg: ReplayConfig): robot.send_action(action) dt_s = time.perf_counter() - start_episode_t - precise_sleep(1 / dataset.fps - 
dt_s) + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) robot.disconnect() diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index a7b18a53c..ab6ce3ced 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -96,7 +96,7 @@ def main(): # Send action to robot _ = robot.send_action(joint_action) - precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0)) + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) # Clean up robot.disconnect() diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 9951b139d..59c524078 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -97,7 +97,7 @@ def main(): # Send action to robot _ = robot.send_action(joint_action) - precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0)) + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) # Clean up robot.disconnect() diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 13fd66507..641a21d03 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -398,7 +398,7 @@ def act_with_policy( if cfg.env.fps is not None: dt_time = time.perf_counter() - start_time - precise_sleep(1 / cfg.env.fps - dt_time) + precise_sleep(max(1 / cfg.env.fps - dt_time, 0.0)) # Communication Functions - Group all gRPC/messaging functions diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index ad1fdf55f..7bc74a959 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -238,7 +238,7 @@ class RobotEnv(gym.Env): reset_follower_position(self.robot, np.array(self.reset_pose)) log_say("Reset the environment done.", play_sounds=True) - precise_sleep(self.reset_time_s - (time.perf_counter() - start_time)) + precise_sleep(max(self.reset_time_s - (time.perf_counter() - start_time), 0.0)) super().reset(seed=seed, options=options) @@ -713,7 +713,7 
@@ def control_loop( transition = env_processor(transition) # Maintain fps timing - precise_sleep(dt - (time.perf_counter() - step_start_time)) + precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) if dataset is not None and cfg.dataset.push_to_hub: logging.info("Pushing dataset to hub") @@ -745,7 +745,7 @@ def replay_trajectory( ) transition = action_processor(transition) env.step(transition[TransitionKey.ACTION]) - precise_sleep(1 / cfg.env.fps - (time.perf_counter() - start_time)) + precise_sleep(max(1 / cfg.env.fps - (time.perf_counter() - start_time), 0.0)) @parser.wrap() diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index eafd32ace..62bf21653 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -374,7 +374,7 @@ def record_loop( log_rerun_data(observation=obs_processed, action=action_values) dt_s = time.perf_counter() - start_loop_t - precise_sleep(1 / fps - dt_s) + precise_sleep(max(1 / fps - dt_s, 0.0)) timestamp = time.perf_counter() - start_episode_t diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index d5808c768..5cde4251c 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -123,7 +123,7 @@ def replay(cfg: ReplayConfig): _ = robot.send_action(processed_action) dt_s = time.perf_counter() - start_episode_t - precise_sleep(1 / dataset.fps - dt_s) + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) robot.disconnect() diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index bf722d6f1..9866e7629 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -177,7 +177,7 @@ def teleop_loop( move_cursor_up(len(robot_action_to_send) + 3) dt_s = time.perf_counter() - loop_start - precise_sleep(1 / fps - dt_s) + precise_sleep(max(1 / fps - dt_s, 0.0)) loop_s = time.perf_counter() 
- loop_start print(f"Teleop loop time: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") move_cursor_up(1) From a9d81e7f67b848c767660689721c24a504d6957f Mon Sep 17 00:00:00 2001 From: Pauline Bailly-Masson <155966238+paulinebm@users.noreply.github.com> Date: Wed, 7 Jan 2026 00:21:03 +0100 Subject: [PATCH 018/111] refactor(ci): Docker Hub image env (#2755) * Refactor Docker Hub image env Updated environment variable usage for Docker Hub credentials and corrected image tag extraction. Signed-off-by: Pauline Bailly-Masson <155966238+paulinebm@users.noreply.github.com> * same Signed-off-by: Pauline Bailly-Masson <155966238+paulinebm@users.noreply.github.com> * Apply suggestions from code review Signed-off-by: Steven Palma * chore(ci): remove duplicated IMAGE_FULL variable definition --------- Signed-off-by: Pauline Bailly-Masson <155966238+paulinebm@users.noreply.github.com> Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- .github/workflows/full_tests.yml | 13 ++++++++----- .github/workflows/unbound_deps_tests.yml | 12 ++++++++---- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 7bf2cb78c..4dce3121a 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -186,15 +186,18 @@ jobs: steps: - name: Get Docker Hub Token and Delete Image # zizmor: ignore[template-injection] + env: + DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} + IMAGE_FULL: ${{ needs.build-and-push-docker.outputs.image_tag }} run: | - IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1) - IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2) - + IMAGE_NAME=$(echo "$IMAGE_FULL" | cut -d':' -f1) + IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2-) echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG" TOKEN=$(curl -s 
-H "Content-Type: application/json" \ -X POST \ - -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \ + -d "{\"username\": \"$DOCKERHUB_LEROBOT_USERNAME\", \"password\": \"$DOCKERHUB_LEROBOT_PASSWORD\"}" \ https://hub.docker.com/v2/users/login/ | jq -r .token) if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then @@ -205,7 +208,7 @@ jobs: HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \ -H "Authorization: JWT ${TOKEN}" \ -X DELETE \ - https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/) + https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/$IMAGE_TAG) if [ "$HTTP_RESPONSE" -eq 204 ]; then echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG" diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index 3908bdc3d..e3ae71cc9 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -162,15 +162,19 @@ jobs: steps: - name: Get Docker Hub Token and Delete Image # zizmor: ignore[template-injection] + env: + DOCKERHUB_LEROBOT_USERNAME: ${{ secrets.DOCKERHUB_LEROBOT_USERNAME }} + DOCKERHUB_LEROBOT_PASSWORD: ${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }} + IMAGE_FULL: ${{ needs.build-and-push-docker.outputs.image_tag }} run: | - IMAGE_NAME=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f1) - IMAGE_TAG=$(echo "${{ needs.build-and-push-docker.outputs.image_tag }}" | cut -d':' -f2) + IMAGE_NAME=$(echo "$IMAGE_FULL" | cut -d':' -f1) + IMAGE_TAG=$(echo "$IMAGE_FULL" | cut -d':' -f2) echo "Attempting to delete image: $IMAGE_NAME:$IMAGE_TAG" TOKEN=$(curl -s -H "Content-Type: application/json" \ -X POST \ - -d '{"username": "${{ secrets.DOCKERHUB_LEROBOT_USERNAME }}", "password": "${{ secrets.DOCKERHUB_LEROBOT_PASSWORD }}"}' \ + -d "{\"username\": \"$DOCKERHUB_LEROBOT_USERNAME\", \"password\": \"$DOCKERHUB_LEROBOT_PASSWORD\"}" \ 
https://hub.docker.com/v2/users/login/ | jq -r .token) if [ "$TOKEN" == "null" ] || [ -z "$TOKEN" ]; then @@ -181,7 +185,7 @@ jobs: HTTP_RESPONSE=$(curl -s -o /dev/null -w "%{http_code}" \ -H "Authorization: JWT ${TOKEN}" \ -X DELETE \ - https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/${IMAGE_TAG}/) + https://hub.docker.com/v2/repositories/${IMAGE_NAME}/tags/$IMAGE_TAG) if [ "$HTTP_RESPONSE" -eq 204 ]; then echo "Successfully deleted Docker image tag: $IMAGE_NAME:$IMAGE_TAG" From ecd8cd23d22a66aaa15099939a7aa9653aa866e2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 7 Jan 2026 11:04:21 +0100 Subject: [PATCH 019/111] chore(dependencies): bound new dependecies (#2759) --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 980844149..f25645319 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -109,7 +109,7 @@ hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] unitree_g1 = [ "pyzmq>=26.2.1,<28.0.0", - "onnxruntime>=1.16.0" + "onnxruntime>=1.16.0,<2.0.0" ] reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"] kinematics = ["lerobot[placo-dep]"] @@ -140,13 +140,13 @@ groot = [ "ninja>=1.11.1,<2.0.0", "flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'" ] -sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14"] +sarm = ["lerobot[transformers-dep]", "faker>=33.0.0,<35.0.0", "matplotlib>=3.10.3,<4.0.0", "qwen-vl-utils>=0.0.14,<0.1.0"] xvla = ["lerobot[transformers-dep]"] hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"] # Features async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3,<4.0.0"] -peft = ["lerobot[transformers-dep]", "peft>=0.18.0"] +peft = ["lerobot[transformers-dep]", "peft>=0.18.0,<1.0.0"] # Development dev = ["pre-commit>=3.7.0,<5.0.0", "debugpy>=1.8.1,<1.9.0", "lerobot[grpcio-dep]", 
"grpcio-tools==1.73.1", "mypy>=1.19.1"] From 7e9d05a799175dce80e8ef1f8db23667ceeeb266 Mon Sep 17 00:00:00 2001 From: Martino Russi <77496684+nepyope@users.noreply.github.com> Date: Wed, 7 Jan 2026 16:05:31 +0100 Subject: [PATCH 020/111] add holosoma locomotion (#2669) Add holosoma locomotion from Amazon-FAR Add reset method to unitree_g1 Format actions as dict Update docs --- docs/source/unitree_g1.mdx | 55 ++-- examples/unitree_g1/gr00t_locomotion.py | 250 ++++++----------- examples/unitree_g1/holosoma_locomotion.py | 264 ++++++++++++++++++ .../robots/unitree_g1/config_unitree_g1.py | 10 +- src/lerobot/robots/unitree_g1/g1_utils.py | 8 - src/lerobot/robots/unitree_g1/unitree_g1.py | 59 +++- 6 files changed, 431 insertions(+), 215 deletions(-) create mode 100644 examples/unitree_g1/holosoma_locomotion.py diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index af06fd742..bdc7eb33d 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -1,21 +1,21 @@ -# Unitree G1 Robot Setup and Control +# Unitree G1 This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion. -## About the Unitree G1 +## About -We offer support for both 29 and 23 DOF G1. We introduce: +We support both 29 and 23 DOF G1 EDU version. 
We introduce: -- **`unitree g1` robot class, handling low level communication with the humanoid** -- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin -- **GR00T locomotion policy** for bipedal walking and balance -- **MuJoCo simulation mode** for testing policies without the physical robot +- **`unitree g1` robot class, handling low level read/write from/to the humanoid** +- **ZMQ socket bridge** for remote communication over wlan, allowing for remote policy deployment as well as over eth or directly on the Orin +- **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma +- **Simulation mode** for testing policies without the physical robot in mujoco --- -## Part 1: Connect to Robot over Ethernet +## Connection guide -### Step 1: Configure Your Computer's Ethernet Interface +### Step 1: Configure Ethernet Interface Set a static IP on the same subnet as the robot: @@ -26,7 +26,7 @@ sudo ip addr add 192.168.123.200/24 dev enp131s0 sudo ip link set enp131s0 up ``` -**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164. +**Note**: The G1's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` with x ≠ 164. ### Step 2: SSH into the Robot @@ -35,25 +35,24 @@ ssh unitree@192.168.123.164 # Password: 123 ``` -You should now be connected to the robot's onboard computer. +You should now be connected to the G1's Orin. --- ## Part 2: Enable WiFi on the Robot -Once connected via Ethernet, follow these steps to enable WiFi: +Wlan0 is disabled by default on the G1. 
To enable it: ### Step 1: Enable WiFi Hardware ```bash -# Unblock WiFi radio sudo rfkill unblock wifi sudo rfkill unblock all -# Bring up WiFi interface +# Bring up wlan0 sudo ip link set wlan0 up -# Enable NetworkManager control +# Enable NetworkManager control of wlan0 sudo nmcli radio wifi on sudo nmcli device set wlan0 managed yes sudo systemctl restart NetworkManager @@ -73,7 +72,7 @@ sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTA sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT ``` -**On the robot:** +**On the G1:** ```bash # Add laptop as default gateway @@ -147,9 +146,9 @@ python src/lerobot/robots/unitree_g1/run_g1_server.py --- -## Part 4: Running GR00T Locomotion +## Part 4: Controlling the robot -With the robot server running, you can now control the robot from your laptop. +With the robot server running, you can now control the robot remotely. Let's launch a locomotion policy ### Step 1: Install LeRobot on your machine @@ -172,34 +171,30 @@ Edit the config file to match your robot's WiFi IP: robot_ip: str = "" # Replace with your robot's WiFi IP. ``` -**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead. - ### Step 3: Run the Locomotion Policy ```bash # Run GR00T locomotion controller python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1" + +# Run Holosoma locomotion controller +python examples/unitree_g1/holosoma_locomotion.py + ``` -### Step 4: Control with Remote - -- **Left stick**: Forward/backward and left/right movement -- **Right stick**: Rotation -- **R1 button**: Raise waist height -- **R2 button**: Lower waist height - Press `Ctrl+C` to stop the policy. --- -## Extra: Running in Simulation Mode (MuJoCo) +## Running in Simulation Mode (MuJoCo) -You can now test and develop policies without a physical robot using MuJoCo. to do so set `is_simulation=True` in config. 
+You can now test and develop policies without a physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. ## Additional Resources - [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python) -- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1) +- [GR00T-WholeBodyControl](https://github.com/NVlabs/GR00T-WholeBodyControl) +- [Holosoma](https://github.com/amazon-far/holosoma) - [LeRobot Documentation](https://github.com/huggingface/lerobot) - [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot) diff --git a/examples/unitree_g1/gr00t_locomotion.py b/examples/unitree_g1/gr00t_locomotion.py index 7cc4e03be..30e5a27e6 100644 --- a/examples/unitree_g1/gr00t_locomotion.py +++ b/examples/unitree_g1/gr00t_locomotion.py @@ -13,16 +13,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" -Example: GR00T Locomotion with Pre-loaded Policies - -This example demonstrates the NEW pattern for loading GR00T policies externally -and passing them to the robot class. 
-""" import argparse import logging -import threading import time from collections import deque @@ -31,24 +24,26 @@ import onnxruntime as ort from huggingface_hub import hf_hub_download from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 +logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32) -GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch -GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee -GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch +GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # Hip pitch +GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # Knee +GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # Ankle pitch MISSING_JOINTS = [] -G1_MODEL = "g1_23" # or "g1_29" +G1_MODEL = "g1_23" # Or "g1_29" if G1_MODEL == "g1_23": - MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw - -LOCOMOTION_ACTION_SCALE = 0.25 - -LOCOMOTION_CONTROL_DT = 0.02 + MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw +# Control parameters +ACTION_SCALE = 0.25 +CONTROL_DT = 0.02 # 50Hz ANG_VEL_SCALE: float = 0.25 DOF_POS_SCALE: float = 1.0 DOF_VEL_SCALE: float = 0.05 @@ -61,12 +56,12 @@ DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1" def load_groot_policies( repo_id: str = DEFAULT_GROOT_REPO_ID, ) -> tuple[ort.InferenceSession, ort.InferenceSession]: - """Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub. + """Load GR00T dual-policy system (Balance + Walk) from the hub. Args: repo_id: Hugging Face Hub repository ID containing the ONNX policies. 
""" - logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...") + logger.info(f"Loading GR00T dual-policy system from the hub ({repo_id})...") # Download ONNX policies from Hugging Face Hub balance_path = hf_hub_download( @@ -88,15 +83,7 @@ def load_groot_policies( class GrootLocomotionController: - """ - Handles GR00T-style locomotion control for the Unitree G1 robot. - - This controller manages: - - Dual-policy system (Balance + Walk) - - 29-joint observation processing - - 15D action output (legs + waist) - - Policy inference and motor command generation - """ + """GR00T lower-body locomotion controller for the Unitree G1.""" def __init__(self, policy_balance, policy_walk, robot, config): self.policy_balance = policy_balance @@ -104,9 +91,9 @@ class GrootLocomotionController: self.robot = robot self.config = config - self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot + self.cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot - # GR00T-specific state + # Robot state self.groot_qj_all = np.zeros(29, dtype=np.float32) self.groot_dqj_all = np.zeros(29, dtype=np.float32) self.groot_action = np.zeros(15, dtype=np.float32) @@ -116,24 +103,20 @@ class GrootLocomotionController: self.groot_height_cmd = 0.74 # Default base height self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) - # input to gr00t is 6 frames (6*86D=516) + # Input to GR00T is 6 frames (6*86D=516) for _ in range(6): self.groot_obs_history.append(np.zeros(86, dtype=np.float32)) - # Thread management - self.locomotion_running = False - self.locomotion_thread = None - logger.info("GrootLocomotionController initialized") - def groot_locomotion_run(self): - # get current observation + def run_step(self): + # Get current observation robot_state = self.robot.get_observation() if robot_state is None: return - # get command from remote controller + # Get command from remote controller if 
robot_state.wireless_remote is not None: self.robot.remote_controller.set(robot_state.wireless_remote) if self.robot.remote_controller.button[0]: # R1 - raise waist @@ -148,15 +131,16 @@ class GrootLocomotionController: self.robot.remote_controller.rx = 0.0 self.robot.remote_controller.ry = 0.0 - self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward - self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right - self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate + self.cmd[0] = self.robot.remote_controller.ly # Forward/backward + self.cmd[1] = self.robot.remote_controller.lx * -1 # Left/right + self.cmd[2] = self.robot.remote_controller.rx * -1 # Rotation rate + # Get joint positions and velocities for i in range(29): self.groot_qj_all[i] = robot_state.motor_state[i].q self.groot_dqj_all[i] = robot_state.motor_state[i].dq - # adapt observation for g1_23dof + # Adapt observation for g1_23dof for idx in MISSING_JOINTS: self.groot_qj_all[idx] = 0.0 self.groot_dqj_all[idx] = 0.0 @@ -165,18 +149,18 @@ class GrootLocomotionController: qj_obs = self.groot_qj_all.copy() dqj_obs = self.groot_dqj_all.copy() - # express imu data in gravity frame of reference + # Express IMU data in gravity frame of reference quat = robot_state.imu_state.quaternion ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) gravity_orientation = self.robot.get_gravity_orientation(quat) - # scale joint positions and velocities before policy inference + # Scale joint positions and velocities before policy inference qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE dqj_obs = dqj_obs * DOF_VEL_SCALE ang_vel_scaled = ang_vel * ANG_VEL_SCALE - # build single frame observation - self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE) + # Build single frame observation + self.groot_obs_single[:3] = self.cmd * np.array(CMD_SCALE) self.groot_obs_single[3] = self.groot_height_cmd self.groot_obs_single[4:7] = 
self.groot_orientation_cmd self.groot_obs_single[7:10] = ang_vel_scaled @@ -194,113 +178,74 @@ class GrootLocomotionController: end_idx = start_idx + 86 self.groot_obs_stacked[start_idx:end_idx] = obs_frame - # Run policy inference (ONNX) with 516D stacked observation - - cmd_magnitude = np.linalg.norm(self.locomotion_cmd) - + cmd_magnitude = np.linalg.norm(self.cmd) selected_policy = ( self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk - ) # balance/standing policy for small commands, walking policy for movement commands + ) # Balance/standing policy for small commands, walking policy for movement commands - # run policy inference + # Run policy inference ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)} ort_outs = selected_policy.run(None, ort_inputs) self.groot_action = ort_outs[0].squeeze() - # transform action back to target joint positions - target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE + # Transform action back to target joint positions + target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * ACTION_SCALE - # command motors + # Build action dict (only first 15 joints for GR00T) + action_dict = {} for i in range(15): - motor_idx = i - self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i] - self.robot.msg.motor_cmd[motor_idx].qd = 0 - self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx] - self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx] - self.robot.msg.motor_cmd[motor_idx].tau = 0 + motor_name = G1_29_JointIndex(i).name + action_dict[f"{motor_name}.q"] = float(target_dof_pos_15[i]) - # adapt action for g1_23dof + # Zero out missing joints for g1_23dof for joint_idx in MISSING_JOINTS: - self.robot.msg.motor_cmd[joint_idx].q = 0.0 - self.robot.msg.motor_cmd[joint_idx].qd = 0 - self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx] - self.robot.msg.motor_cmd[joint_idx].kd = 
self.robot.kd[joint_idx] - self.robot.msg.motor_cmd[joint_idx].tau = 0 + motor_name = G1_29_JointIndex(joint_idx).name + action_dict[f"{motor_name}.q"] = 0.0 - # send action to robot - self.robot.send_action(self.robot.msg) + # Send action to robot + self.robot.send_action(action_dict) - def _locomotion_thread_loop(self): - """Background thread that runs the locomotion policy at specified rate.""" - logger.info("Locomotion thread started") - while self.locomotion_running: + +def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None: + """Main function to run the GR00T locomotion controller. + + Args: + repo_id: Hugging Face Hub repository ID for GR00T policies. + """ + # Load policies + policy_balance, policy_walk = load_groot_policies(repo_id=repo_id) + + # Initialize robot + config = UnitreeG1Config() + robot = UnitreeG1(config) + + # Initialize gr00T locomotion controller + groot_controller = GrootLocomotionController( + policy_balance=policy_balance, + policy_walk=policy_walk, + robot=robot, + config=config, + ) + + try: + robot.reset(CONTROL_DT, GROOT_DEFAULT_ANGLES) + + logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate, R1=raise waist, R2=lower waist") + logger.info("Press Ctrl+C to stop") + + # Run step + while True: start_time = time.time() - try: - self.groot_locomotion_run() - except Exception as e: - logger.error(f"Error in locomotion loop: {e}") - - # Sleep to maintain control rate + groot_controller.run_step() elapsed = time.time() - start_time - sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed) + sleep_time = max(0, CONTROL_DT - elapsed) time.sleep(sleep_time) - logger.info("Locomotion thread stopped") - - def start_locomotion_thread(self): - if self.locomotion_running: - logger.warning("Locomotion thread already running") - return - - logger.info("Starting locomotion control thread...") - self.locomotion_running = True - self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True) - 
self.locomotion_thread.start() - - logger.info("Locomotion control thread started!") - - def stop_locomotion_thread(self): - if not self.locomotion_running: - return - - logger.info("Stopping locomotion control thread...") - self.locomotion_running = False - if self.locomotion_thread: - self.locomotion_thread.join(timeout=2.0) - logger.info("Locomotion control thread stopped") - - def reset_robot(self): - """Move robot legs to default standing position over 2 seconds (arms are not moved).""" - total_time = 3.0 - num_step = int(total_time / self.robot.control_dt) - - # Only control legs, not arms (first 12 joints) - default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles - dof_size = len(default_pos) - - # Get current lowstate - robot_state = self.robot.get_observation() - - # Record the current leg positions - init_dof_pos = np.zeros(dof_size, dtype=np.float32) - for i in range(dof_size): - init_dof_pos[i] = robot_state.motor_state[i].q - - # Move legs to default pos - for i in range(num_step): - alpha = i / num_step - for motor_idx in range(dof_size): - target_pos = default_pos[motor_idx] - self.robot.msg.motor_cmd[motor_idx].q = ( - init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha - ) - self.robot.msg.motor_cmd[motor_idx].qd = 0 - self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx] - self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx] - self.robot.msg.motor_cmd[motor_idx].tau = 0 - self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg) - self.robot.lowcmd_publisher.Write(self.robot.msg) - time.sleep(self.robot.control_dt) - logger.info("Reached default position (legs only)") + except KeyboardInterrupt: + logger.info("Stopping locomotion...") + finally: + if robot.is_connected: + robot.disconnect() + logger.info("Done!") if __name__ == "__main__": @@ -313,35 +258,4 @@ if __name__ == "__main__": ) args = parser.parse_args() - # load policies - policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id) 
- - # initialize robot - config = UnitreeG1Config() - robot = UnitreeG1(config) - - # initialize gr00t locomotion controller - groot_controller = GrootLocomotionController( - policy_balance=policy_balance, - policy_walk=policy_walk, - robot=robot, - config=config, - ) - - # reset legs and start locomotion thread - try: - groot_controller.reset_robot() - groot_controller.start_locomotion_thread() - - # log status - logger.info("Robot initialized with GR00T locomotion policies") - logger.info("Locomotion controller running in background thread") - logger.info("Press Ctrl+C to stop") - - # keep robot alive - while True: - time.sleep(1.0) - except KeyboardInterrupt: - print("\nStopping locomotion...") - groot_controller.stop_locomotion_thread() - print("Done!") + run(repo_id=args.repo_id) diff --git a/examples/unitree_g1/holosoma_locomotion.py b/examples/unitree_g1/holosoma_locomotion.py new file mode 100644 index 000000000..017f7835a --- /dev/null +++ b/examples/unitree_g1/holosoma_locomotion.py @@ -0,0 +1,264 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse +import json +import logging +import time + +import numpy as np +import onnx +import onnxruntime as ort +from huggingface_hub import hf_hub_download + +from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1 + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +DEFAULT_ANGLES = np.zeros(29, dtype=np.float32) +DEFAULT_ANGLES[[0, 6]] = -0.312 # Hip pitch +DEFAULT_ANGLES[[3, 9]] = 0.669 # Knee +DEFAULT_ANGLES[[4, 10]] = -0.363 # Ankle pitch +DEFAULT_ANGLES[[15, 22]] = 0.2 # Shoulder pitch +DEFAULT_ANGLES[16] = 0.2 # Left shoulder roll +DEFAULT_ANGLES[23] = -0.2 # Right shoulder roll +DEFAULT_ANGLES[[18, 25]] = 0.6 # Elbow + +MISSING_JOINTS = [] +G1_MODEL = "g1_23" # Or "g1_29" +if G1_MODEL == "g1_23": + MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # Waist yaw/pitch, wrist pitch/yaw + +# Control parameters +ACTION_SCALE = 0.25 +CONTROL_DT = 0.02 # 50Hz +ANG_VEL_SCALE = 0.25 +DOF_POS_SCALE = 1.0 +DOF_VEL_SCALE = 0.05 +GAIT_PERIOD = 1.0 + + +DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion" + +# Policy filename mapping +POLICY_FILES = { + "fastsac": "fastsac_g1_29dof.onnx", + "ppo": "ppo_g1_29dof.onnx", +} + + +def load_policy( + repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, + policy_type: str = "fastsac", +) -> tuple[ort.InferenceSession, np.ndarray, np.ndarray]: + """Load Holosoma locomotion policy and extract KP/KD from metadata. + + Args: + repo_id: Hugging Face Hub repo ID + policy_type: Either "fastsac" (default) or "ppo" + + Returns: + (policy, kp, kd) tuple + """ + if policy_type not in POLICY_FILES: + raise ValueError(f"Unknown policy type: {policy_type}. 
Choose from: {list(POLICY_FILES.keys())}") + + filename = POLICY_FILES[policy_type] + logger.info(f"Loading {policy_type.upper()} policy from: {repo_id}/{filename}") + policy_path = hf_hub_download(repo_id=repo_id, filename=filename) + + policy = ort.InferenceSession(policy_path) + logger.info(f"Policy loaded: {policy.get_inputs()[0].shape} → {policy.get_outputs()[0].shape}") + + # Extract KP/KD from ONNX metadata + model = onnx.load(policy_path) + metadata = {prop.key: prop.value for prop in model.metadata_props} + + if "kp" not in metadata or "kd" not in metadata: + raise ValueError("ONNX model must contain 'kp' and 'kd' in metadata") + + kp = np.array(json.loads(metadata["kp"]), dtype=np.float32) + kd = np.array(json.loads(metadata["kd"]), dtype=np.float32) + logger.info(f"Loaded KP/KD from ONNX ({len(kp)} joints)") + + return policy, kp, kd + + +class HolosomaLocomotionController: + """Holosoma whole-body locomotion controller for Unitree G1.""" + + def __init__(self, policy, robot, kp: np.ndarray, kd: np.ndarray): + self.policy = policy + self.robot = robot + + # Override robot's PD gains with policy gains + self.robot.kp = kp + self.robot.kd = kd + + self.cmd = np.zeros(3, dtype=np.float32) + + # Robot state + self.qj = np.zeros(29, dtype=np.float32) + self.dqj = np.zeros(29, dtype=np.float32) + self.obs = np.zeros(100, dtype=np.float32) + self.last_action = np.zeros(29, dtype=np.float32) + + # Gait phase + self.phase = np.array([[0.0, np.pi]], dtype=np.float32) + self.phase_dt = 2 * np.pi / ((1.0 / CONTROL_DT) * GAIT_PERIOD) + self.is_standing = True + + def run_step(self): + # Get current observation + robot_state = self.robot.get_observation() + + if robot_state is None: + return + + # Get command from remote controller + if robot_state.wireless_remote is not None: + self.robot.remote_controller.set(robot_state.wireless_remote) + + ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0 + lx = 
self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0 + rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0 + self.cmd[:] = [ly, -lx, -rx] + + # Get joint positions and velocities + for i in range(29): + self.qj[i] = robot_state.motor_state[i].q + self.dqj[i] = robot_state.motor_state[i].dq + + # Adapt observation for g1_23dof + for idx in MISSING_JOINTS: + self.qj[idx] = 0.0 + self.dqj[idx] = 0.0 + + # Express IMU data in gravity frame of reference + quat = robot_state.imu_state.quaternion + ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) + gravity = self.robot.get_gravity_orientation(quat) + + # Scale joint positions and velocities before policy inference + qj_obs = (self.qj - DEFAULT_ANGLES) * DOF_POS_SCALE + dqj_obs = self.dqj * DOF_VEL_SCALE + ang_vel_s = ang_vel * ANG_VEL_SCALE + + # Update gait phase + if np.linalg.norm(self.cmd[:2]) < 0.01 and abs(self.cmd[2]) < 0.01: + self.phase[0, :] = np.pi + self.is_standing = True + elif self.is_standing: + self.phase = np.array([[0.0, np.pi]], dtype=np.float32) + self.is_standing = False + else: + self.phase = np.fmod(self.phase + self.phase_dt + np.pi, 2 * np.pi) - np.pi + + sin_ph = np.sin(self.phase[0]) + cos_ph = np.cos(self.phase[0]) + + # Build observations + self.obs[0:29] = self.last_action + self.obs[29:32] = ang_vel_s + self.obs[32] = self.cmd[2] + self.obs[33:35] = self.cmd[:2] + self.obs[35:37] = cos_ph + self.obs[37:66] = qj_obs + self.obs[66:95] = dqj_obs + self.obs[95:98] = gravity + self.obs[98:100] = sin_ph + + # Run policy inference + ort_in = {self.policy.get_inputs()[0].name: self.obs.reshape(1, -1).astype(np.float32)} + raw_action = self.policy.run(None, ort_in)[0].squeeze() + action = np.clip(raw_action, -100.0, 100.0) + self.last_action = action.copy() + + # Transform action back to target joint positions + target = DEFAULT_ANGLES + action * ACTION_SCALE + + # Build action dict + action_dict = {} + 
for motor in G1_29_JointIndex: + action_dict[f"{motor.name}.q"] = float(target[motor.value]) + + # Zero out missing joints for g1_23dof + for joint_idx in MISSING_JOINTS: + motor_name = G1_29_JointIndex(joint_idx).name + action_dict[f"{motor_name}.q"] = 0.0 + + # Send action to robot + self.robot.send_action(action_dict) + + +def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") -> None: + """Main function to run the Holosoma locomotion controller. + + Args: + repo_id: Hugging Face Hub repository ID for Holosoma policies. + policy_type: Policy type to use ('fastsac' or 'ppo'). + """ + # Load policy and gains + policy, kp, kd = load_policy(repo_id=repo_id, policy_type=policy_type) + + # Initialize robot + config = UnitreeG1Config() + robot = UnitreeG1(config) + + holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd) + + try: + robot.reset(CONTROL_DT, DEFAULT_ANGLES) + + logger.info("Use joystick: LY=fwd/back, LX=left/right, RX=rotate") + logger.info("Press Ctrl+C to stop") + + # Run step + while True: + start_time = time.time() + holosoma_controller.run_step() + elapsed = time.time() - start_time + sleep_time = max(0, CONTROL_DT - elapsed) + time.sleep(sleep_time) + except KeyboardInterrupt: + logger.info("Stopping locomotion...") + finally: + if robot.is_connected: + robot.disconnect() + logger.info("Done!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1") + parser.add_argument( + "--repo-id", + type=str, + default=DEFAULT_HOLOSOMA_REPO_ID, + help=f"Hugging Face Hub repo ID for Holosoma policies (default: {DEFAULT_HOLOSOMA_REPO_ID})", + ) + parser.add_argument( + "--policy", + type=str, + choices=["fastsac", "ppo"], + default="fastsac", + help="Policy type to use: 'fastsac' (default) or 'ppo'", + ) + args = parser.parse_args() + + run(repo_id=args.repo_id, policy_type=args.policy) diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py 
b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index ac65f1a7b..c66edbd1c 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -49,10 +49,14 @@ class UnitreeG1Config(RobotConfig): kp: list[float] = field(default_factory=lambda: _DEFAULT_KP.copy()) kd: list[float] = field(default_factory=lambda: _DEFAULT_KD.copy()) + # Default joint positions + default_positions: list[float] = field(default_factory=lambda: [0.0] * 29) + + # Control loop timestep control_dt: float = 1.0 / 250.0 # 250Hz - # launch mujoco simulation + # Launch mujoco simulation is_simulation: bool = True - # socket config for ZMQ bridge - robot_ip: str = "192.168.123.164" + # Socket config for ZMQ bridge + robot_ip: str = "192.168.123.164" # default G1 IP diff --git a/src/lerobot/robots/unitree_g1/g1_utils.py b/src/lerobot/robots/unitree_g1/g1_utils.py index c045d73d2..3c41ee985 100644 --- a/src/lerobot/robots/unitree_g1/g1_utils.py +++ b/src/lerobot/robots/unitree_g1/g1_utils.py @@ -79,11 +79,3 @@ class G1_29_JointIndex(IntEnum): kRightWristRoll = 26 kRightWristPitch = 27 kRightWristYaw = 28 - - # not used - kNotUsedJoint0 = 29 - kNotUsedJoint1 = 30 - kNotUsedJoint2 = 31 - kNotUsedJoint3 = 32 - kNotUsedJoint4 = 33 - kNotUsedJoint5 = 34 diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index cce9d1b1e..5bc4f3110 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -43,10 +43,7 @@ logger = logging.getLogger(__name__) kTopicLowCommand_Debug = "rt/lowcmd" kTopicLowState = "rt/lowstate" -G1_29_Num_Motors = 35 -G1_23_Num_Motors = 35 -H1_2_Num_Motors = 35 -H1_Num_Motors = 20 +G1_29_Num_Motors = 29 @dataclass @@ -266,8 +263,17 @@ class UnitreeG1(Robot): return {**self._motors_ft, **self._cameras_ft} def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - self.msg.crc = self.crc.Crc(action) - 
self.lowcmd_publisher.Write(action) + for motor in G1_29_JointIndex: + key = f"{motor.name}.q" + if key in action: + self.msg.motor_cmd[motor.value].q = action[key] + self.msg.motor_cmd[motor.value].qd = 0 + self.msg.motor_cmd[motor.value].kp = self.kp[motor.value] + self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] + self.msg.motor_cmd[motor.value].tau = 0 + + self.msg.crc = self.crc.Crc(self.msg) + self.lowcmd_publisher.Write(self.msg) return action def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion @@ -282,3 +288,44 @@ class UnitreeG1(Robot): gravity_orientation[1] = -2 * (qz * qy + qw * qx) gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz) return gravity_orientation + + def reset( + self, + control_dt: float | None = None, + default_positions: list[float] | None = None, + ) -> None: # interpolate to default position + if control_dt is None: + control_dt = self.config.control_dt + if default_positions is None: + default_positions = np.array(self.config.default_positions, dtype=np.float32) + + total_time = 3.0 + num_steps = int(total_time / control_dt) + + # get current state + robot_state = self.get_observation() + + # record current positions + init_dof_pos = np.zeros(29, dtype=np.float32) + for i in range(29): + init_dof_pos[i] = robot_state.motor_state[i].q + + # Interpolate to default position + for step in range(num_steps): + start_time = time.time() + + alpha = step / num_steps + action_dict = {} + for motor in G1_29_JointIndex: + target_pos = default_positions[motor.value] + interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha + action_dict[f"{motor.name}.q"] = float(interp_pos) + + self.send_action(action_dict) + + # Maintain constant control rate + elapsed = time.time() - start_time + sleep_time = max(0, control_dt - elapsed) + time.sleep(sleep_time) + + logger.info("Reached default position") From f844c7a4585f01f4504af250adffa23990cfe195 Mon Sep 17 00:00:00 2001 From: Steven Palma 
Date: Wed, 7 Jan 2026 17:30:45 +0100 Subject: [PATCH 021/111] feat(visualization): allow remote viewer + compress rerun images (#2756) * feat(visualization): allow remote viewer + compress rerun images * fix(tests): allow named argument in mocked rerun * feat(visualization): ip instead or url & cli arg for compressing images --- src/lerobot/scripts/lerobot_dataset_viz.py | 14 ++++++++++-- src/lerobot/scripts/lerobot_record.py | 19 ++++++++++++++-- src/lerobot/scripts/lerobot_teleoperate.py | 17 ++++++++++++++- src/lerobot/utils/visualization_utils.py | 25 +++++++++++++++++----- tests/utils/test_visualization_utils.py | 5 ++++- 5 files changed, 69 insertions(+), 11 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 974762b0b..2cd48eab8 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -96,6 +96,7 @@ def visualize_dataset( ws_port: int = 9087, save: bool = False, output_dir: Path | None = None, + display_compressed_images: bool = False, ) -> Path | None: if save: assert output_dir is not None, ( @@ -137,8 +138,9 @@ def visualize_dataset( # display each camera image for key in dataset.meta.camera_keys: - # TODO(rcadene): add `.compress()`? is it lossless? - rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) + img = to_hwc_uint8_numpy(batch[key][i]) + img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img) + rr.log(key, entity=img_entity) # display each dimension of action space (e.g. 
actuators command) if ACTION in batch: @@ -261,6 +263,14 @@ def main(): ), ) + parser.add_argument( + "--display-compressed-images", + type=bool, + required=True, + default=False, + help="If set, display compressed images in Rerun instead of uncompressed ones.", + ) + args = parser.parse_args() kwargs = vars(args) repo_id = kwargs.pop("repo_id") diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 62bf21653..ced6b307f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -185,6 +185,12 @@ class RecordConfig: policy: PreTrainedConfig | None = None # Display all cameras on screen display_data: bool = False + # Display data on a remote Rerun server + display_ip: str | None = None + # Port of the remote Rerun server + display_port: int | None = None + # Whether to display compressed images in Rerun + display_compressed_images: bool = False # Use vocal synthesis to read events. play_sounds: bool = True # Resume recording on an existing dataset. 
@@ -261,6 +267,7 @@ def record_loop( control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, + display_compressed_images: bool = False, ): if dataset is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") @@ -371,7 +378,9 @@ def record_loop( dataset.add_frame(frame) if display_data: - log_rerun_data(observation=obs_processed, action=action_values) + log_rerun_data( + observation=obs_processed, action=action_values, compress_images=display_compressed_images + ) dt_s = time.perf_counter() - start_loop_t precise_sleep(max(1 / fps - dt_s, 0.0)) @@ -384,7 +393,12 @@ def record(cfg: RecordConfig) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - init_rerun(session_name="recording") + init_rerun(session_name="recording", ip=cfg.display_ip, port=cfg.display_port) + display_compressed_images = ( + True + if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None) + else cfg.display_compressed_images + ) robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None @@ -478,6 +492,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, + display_compressed_images=display_compressed_images, ) # Execute a few seconds without recording to give time to manually reset the environment diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 9866e7629..95025891e 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -108,6 +108,12 @@ class TeleoperateConfig: teleop_time_s: float | None = None # Display all cameras on screen display_data: bool = False + # Display data on a remote Rerun server + display_ip: str | None = None + # 
Port of the remote Rerun server + display_port: int | None = None + # Whether to display compressed images in Rerun + display_compressed_images: bool = False def teleop_loop( @@ -119,6 +125,7 @@ def teleop_loop( robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation], display_data: bool = False, duration: float | None = None, + display_compressed_images: bool = False, ): """ This function continuously reads actions from a teleoperation device, processes them through optional @@ -130,6 +137,7 @@ def teleop_loop( robot: The robot instance being controlled. fps: The target frequency for the control loop in frames per second. display_data: If True, fetches robot observations and displays them in the console and Rerun. + display_compressed_images: If True, compresses images before sending them to Rerun for display. duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely. teleop_action_processor: An optional pipeline to process raw actions from the teleoperator. robot_action_processor: An optional pipeline to process actions before they are sent to the robot. 
@@ -167,6 +175,7 @@ def teleop_loop( log_rerun_data( observation=obs_transition, action=teleop_action, + compress_images=display_compressed_images, ) print("\n" + "-" * (display_len + 10)) @@ -191,7 +200,12 @@ def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - init_rerun(session_name="teleoperation") + init_rerun(session_name="teleoperation", ip=cfg.display_ip, port=cfg.display_port) + display_compressed_images = ( + True + if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None) + else cfg.display_compressed_images + ) teleop = make_teleoperator_from_config(cfg.teleop) robot = make_robot_from_config(cfg.robot) @@ -210,6 +224,7 @@ def teleoperate(cfg: TeleoperateConfig): teleop_action_processor=teleop_action_processor, robot_action_processor=robot_action_processor, robot_observation_processor=robot_observation_processor, + display_compressed_images=display_compressed_images, ) except KeyboardInterrupt: pass diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 991b10247..9143d0f66 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -22,13 +22,25 @@ import rerun as rr from .constants import OBS_PREFIX, OBS_STR -def init_rerun(session_name: str = "lerobot_control_loop") -> None: - """Initializes the Rerun SDK for visualizing the control loop.""" +def init_rerun( + session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None +) -> None: + """ + Initializes the Rerun SDK for visualizing the control loop. + + Args: + session_name: Name of the Rerun session. + ip: Optional IP for connecting to a Rerun server. + port: Optional port for connecting to a Rerun server. 
+ """ batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size rr.init(session_name) memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%") - rr.spawn(memory_limit=memory_limit) + if ip and port: + rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy") + else: + rr.spawn(memory_limit=memory_limit) def _is_scalar(x): @@ -40,6 +52,7 @@ def _is_scalar(x): def log_rerun_data( observation: dict[str, Any] | None = None, action: dict[str, Any] | None = None, + compress_images: bool = False, ) -> None: """ Logs observation and action data to Rerun for real-time visualization. @@ -48,7 +61,7 @@ def log_rerun_data( to the Rerun viewer. It handles different data types appropriately: - Scalars values (floats, ints) are logged as `rr.Scalars`. - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed - from CHW to HWC format and logged as `rr.Image`. + from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`. - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. - Other multi-dimensional arrays are flattened and logged as individual scalars. @@ -57,6 +70,7 @@ def log_rerun_data( Args: observation: An optional dictionary containing observation data to log. action: An optional dictionary containing action data to log. + compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality. 
""" if observation: for k, v in observation.items(): @@ -75,7 +89,8 @@ def log_rerun_data( for i, vi in enumerate(arr): rr.log(f"{key}_{i}", rr.Scalars(float(vi))) else: - rr.log(key, rr.Image(arr), static=True) + img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr) + rr.log(key, entity=img_entity, static=True) if action: for k, v in action.items(): diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index f573de166..408f636cb 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -41,7 +41,10 @@ def mock_rerun(monkeypatch): def __init__(self, arr): self.arr = arr - def dummy_log(key, obj, **kwargs): + def dummy_log(key, obj=None, **kwargs): + # Accept either positional `obj` or keyword `entity` and record remaining kwargs. + if obj is None and "entity" in kwargs: + obj = kwargs.pop("entity") calls.append((key, obj, kwargs)) dummy_rr = SimpleNamespace( From 4f7cd8d369586911d30d111684478956e47ffa18 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 7 Jan 2026 17:33:36 +0100 Subject: [PATCH 022/111] Revert "feat(visualization): allow remote viewer + compress rerun images (#2756)" (#2766) This reverts commit f844c7a4585f01f4504af250adffa23990cfe195. 
--- src/lerobot/scripts/lerobot_dataset_viz.py | 14 ++---------- src/lerobot/scripts/lerobot_record.py | 19 ++-------------- src/lerobot/scripts/lerobot_teleoperate.py | 17 +-------------- src/lerobot/utils/visualization_utils.py | 25 +++++----------------- tests/utils/test_visualization_utils.py | 5 +---- 5 files changed, 11 insertions(+), 69 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2cd48eab8..974762b0b 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -96,7 +96,6 @@ def visualize_dataset( ws_port: int = 9087, save: bool = False, output_dir: Path | None = None, - display_compressed_images: bool = False, ) -> Path | None: if save: assert output_dir is not None, ( @@ -138,9 +137,8 @@ def visualize_dataset( # display each camera image for key in dataset.meta.camera_keys: - img = to_hwc_uint8_numpy(batch[key][i]) - img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img) - rr.log(key, entity=img_entity) + # TODO(rcadene): add `.compress()`? is it lossless? + rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) # display each dimension of action space (e.g. 
actuators command) if ACTION in batch: @@ -263,14 +261,6 @@ def main(): ), ) - parser.add_argument( - "--display-compressed-images", - type=bool, - required=True, - default=False, - help="If set, display compressed images in Rerun instead of uncompressed ones.", - ) - args = parser.parse_args() kwargs = vars(args) repo_id = kwargs.pop("repo_id") diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index ced6b307f..62bf21653 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -185,12 +185,6 @@ class RecordConfig: policy: PreTrainedConfig | None = None # Display all cameras on screen display_data: bool = False - # Display data on a remote Rerun server - display_ip: str | None = None - # Port of the remote Rerun server - display_port: int | None = None - # Whether to display compressed images in Rerun - display_compressed_images: bool = False # Use vocal synthesis to read events. play_sounds: bool = True # Resume recording on an existing dataset. 
@@ -267,7 +261,6 @@ def record_loop( control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, - display_compressed_images: bool = False, ): if dataset is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") @@ -378,9 +371,7 @@ def record_loop( dataset.add_frame(frame) if display_data: - log_rerun_data( - observation=obs_processed, action=action_values, compress_images=display_compressed_images - ) + log_rerun_data(observation=obs_processed, action=action_values) dt_s = time.perf_counter() - start_loop_t precise_sleep(max(1 / fps - dt_s, 0.0)) @@ -393,12 +384,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - init_rerun(session_name="recording", ip=cfg.display_ip, port=cfg.display_port) - display_compressed_images = ( - True - if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None) - else cfg.display_compressed_images - ) + init_rerun(session_name="recording") robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None @@ -492,7 +478,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset: control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, - display_compressed_images=display_compressed_images, ) # Execute a few seconds without recording to give time to manually reset the environment diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 95025891e..9866e7629 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -108,12 +108,6 @@ class TeleoperateConfig: teleop_time_s: float | None = None # Display all cameras on screen display_data: bool = False - # Display data on a remote Rerun server - display_ip: str | None = None - # 
Port of the remote Rerun server - display_port: int | None = None - # Whether to display compressed images in Rerun - display_compressed_images: bool = False def teleop_loop( @@ -125,7 +119,6 @@ def teleop_loop( robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation], display_data: bool = False, duration: float | None = None, - display_compressed_images: bool = False, ): """ This function continuously reads actions from a teleoperation device, processes them through optional @@ -137,7 +130,6 @@ def teleop_loop( robot: The robot instance being controlled. fps: The target frequency for the control loop in frames per second. display_data: If True, fetches robot observations and displays them in the console and Rerun. - display_compressed_images: If True, compresses images before sending them to Rerun for display. duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely. teleop_action_processor: An optional pipeline to process raw actions from the teleoperator. robot_action_processor: An optional pipeline to process actions before they are sent to the robot. 
@@ -175,7 +167,6 @@ def teleop_loop( log_rerun_data( observation=obs_transition, action=teleop_action, - compress_images=display_compressed_images, ) print("\n" + "-" * (display_len + 10)) @@ -200,12 +191,7 @@ def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - init_rerun(session_name="teleoperation", ip=cfg.display_ip, port=cfg.display_port) - display_compressed_images = ( - True - if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None) - else cfg.display_compressed_images - ) + init_rerun(session_name="teleoperation") teleop = make_teleoperator_from_config(cfg.teleop) robot = make_robot_from_config(cfg.robot) @@ -224,7 +210,6 @@ def teleoperate(cfg: TeleoperateConfig): teleop_action_processor=teleop_action_processor, robot_action_processor=robot_action_processor, robot_observation_processor=robot_observation_processor, - display_compressed_images=display_compressed_images, ) except KeyboardInterrupt: pass diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 9143d0f66..991b10247 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -22,25 +22,13 @@ import rerun as rr from .constants import OBS_PREFIX, OBS_STR -def init_rerun( - session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None -) -> None: - """ - Initializes the Rerun SDK for visualizing the control loop. - - Args: - session_name: Name of the Rerun session. - ip: Optional IP for connecting to a Rerun server. - port: Optional port for connecting to a Rerun server. 
- """ +def init_rerun(session_name: str = "lerobot_control_loop") -> None: + """Initializes the Rerun SDK for visualizing the control loop.""" batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size rr.init(session_name) memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%") - if ip and port: - rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy") - else: - rr.spawn(memory_limit=memory_limit) + rr.spawn(memory_limit=memory_limit) def _is_scalar(x): @@ -52,7 +40,6 @@ def _is_scalar(x): def log_rerun_data( observation: dict[str, Any] | None = None, action: dict[str, Any] | None = None, - compress_images: bool = False, ) -> None: """ Logs observation and action data to Rerun for real-time visualization. @@ -61,7 +48,7 @@ def log_rerun_data( to the Rerun viewer. It handles different data types appropriately: - Scalars values (floats, ints) are logged as `rr.Scalars`. - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed - from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`. + from CHW to HWC format and logged as `rr.Image`. - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. - Other multi-dimensional arrays are flattened and logged as individual scalars. @@ -70,7 +57,6 @@ def log_rerun_data( Args: observation: An optional dictionary containing observation data to log. action: An optional dictionary containing action data to log. - compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality. 
""" if observation: for k, v in observation.items(): @@ -89,8 +75,7 @@ def log_rerun_data( for i, vi in enumerate(arr): rr.log(f"{key}_{i}", rr.Scalars(float(vi))) else: - img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr) - rr.log(key, entity=img_entity, static=True) + rr.log(key, rr.Image(arr), static=True) if action: for k, v in action.items(): diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index 408f636cb..f573de166 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -41,10 +41,7 @@ def mock_rerun(monkeypatch): def __init__(self, arr): self.arr = arr - def dummy_log(key, obj=None, **kwargs): - # Accept either positional `obj` or keyword `entity` and record remaining kwargs. - if obj is None and "entity" in kwargs: - obj = kwargs.pop("entity") + def dummy_log(key, obj, **kwargs): calls.append((key, obj, kwargs)) dummy_rr = SimpleNamespace( From fbe4c8b94fdb1d62773e0f285f75aebb5b9ec27d Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 7 Jan 2026 17:38:13 +0100 Subject: [PATCH 023/111] Feat/remote rerunviz encoded images (#2767) * feat(visualization): allow remote viewer + compress rerun images * fix(tests): allow named argument in mocked rerun * feat(visualization): ip instead or url & cli arg for compressing images --------- Co-authored-by: J4nn1K --- src/lerobot/scripts/lerobot_dataset_viz.py | 14 ++++++++++-- src/lerobot/scripts/lerobot_record.py | 19 ++++++++++++++-- src/lerobot/scripts/lerobot_teleoperate.py | 17 ++++++++++++++- src/lerobot/utils/visualization_utils.py | 25 +++++++++++++++++----- tests/utils/test_visualization_utils.py | 5 ++++- 5 files changed, 69 insertions(+), 11 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 974762b0b..2cd48eab8 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -96,6 
+96,7 @@ def visualize_dataset( ws_port: int = 9087, save: bool = False, output_dir: Path | None = None, + display_compressed_images: bool = False, ) -> Path | None: if save: assert output_dir is not None, ( @@ -137,8 +138,9 @@ def visualize_dataset( # display each camera image for key in dataset.meta.camera_keys: - # TODO(rcadene): add `.compress()`? is it lossless? - rr.log(key, rr.Image(to_hwc_uint8_numpy(batch[key][i]))) + img = to_hwc_uint8_numpy(batch[key][i]) + img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img) + rr.log(key, entity=img_entity) # display each dimension of action space (e.g. actuators command) if ACTION in batch: @@ -261,6 +263,14 @@ def main(): ), ) + parser.add_argument( + "--display-compressed-images", + type=bool, + required=True, + default=False, + help="If set, display compressed images in Rerun instead of uncompressed ones.", + ) + args = parser.parse_args() kwargs = vars(args) repo_id = kwargs.pop("repo_id") diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 62bf21653..ced6b307f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -185,6 +185,12 @@ class RecordConfig: policy: PreTrainedConfig | None = None # Display all cameras on screen display_data: bool = False + # Display data on a remote Rerun server + display_ip: str | None = None + # Port of the remote Rerun server + display_port: int | None = None + # Whether to display compressed images in Rerun + display_compressed_images: bool = False # Use vocal synthesis to read events. play_sounds: bool = True # Resume recording on an existing dataset. 
@@ -261,6 +267,7 @@ def record_loop( control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, + display_compressed_images: bool = False, ): if dataset is not None and dataset.fps != fps: raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") @@ -371,7 +378,9 @@ def record_loop( dataset.add_frame(frame) if display_data: - log_rerun_data(observation=obs_processed, action=action_values) + log_rerun_data( + observation=obs_processed, action=action_values, compress_images=display_compressed_images + ) dt_s = time.perf_counter() - start_loop_t precise_sleep(max(1 / fps - dt_s, 0.0)) @@ -384,7 +393,12 @@ def record(cfg: RecordConfig) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - init_rerun(session_name="recording") + init_rerun(session_name="recording", ip=cfg.display_ip, port=cfg.display_port) + display_compressed_images = ( + True + if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None) + else cfg.display_compressed_images + ) robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None @@ -478,6 +492,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset: control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, display_data=cfg.display_data, + display_compressed_images=display_compressed_images, ) # Execute a few seconds without recording to give time to manually reset the environment diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 9866e7629..95025891e 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -108,6 +108,12 @@ class TeleoperateConfig: teleop_time_s: float | None = None # Display all cameras on screen display_data: bool = False + # Display data on a remote Rerun server + display_ip: str | None = None + # 
Port of the remote Rerun server + display_port: int | None = None + # Whether to display compressed images in Rerun + display_compressed_images: bool = False def teleop_loop( @@ -119,6 +125,7 @@ def teleop_loop( robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation], display_data: bool = False, duration: float | None = None, + display_compressed_images: bool = False, ): """ This function continuously reads actions from a teleoperation device, processes them through optional @@ -130,6 +137,7 @@ def teleop_loop( robot: The robot instance being controlled. fps: The target frequency for the control loop in frames per second. display_data: If True, fetches robot observations and displays them in the console and Rerun. + display_compressed_images: If True, compresses images before sending them to Rerun for display. duration: The maximum duration of the teleoperation loop in seconds. If None, the loop runs indefinitely. teleop_action_processor: An optional pipeline to process raw actions from the teleoperator. robot_action_processor: An optional pipeline to process actions before they are sent to the robot. 
@@ -167,6 +175,7 @@ def teleop_loop( log_rerun_data( observation=obs_transition, action=teleop_action, + compress_images=display_compressed_images, ) print("\n" + "-" * (display_len + 10)) @@ -191,7 +200,12 @@ def teleoperate(cfg: TeleoperateConfig): init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: - init_rerun(session_name="teleoperation") + init_rerun(session_name="teleoperation", ip=cfg.display_ip, port=cfg.display_port) + display_compressed_images = ( + True + if (cfg.display_data and cfg.display_ip is not None and cfg.display_port is not None) + else cfg.display_compressed_images + ) teleop = make_teleoperator_from_config(cfg.teleop) robot = make_robot_from_config(cfg.robot) @@ -210,6 +224,7 @@ def teleoperate(cfg: TeleoperateConfig): teleop_action_processor=teleop_action_processor, robot_action_processor=robot_action_processor, robot_observation_processor=robot_observation_processor, + display_compressed_images=display_compressed_images, ) except KeyboardInterrupt: pass diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 991b10247..9143d0f66 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -22,13 +22,25 @@ import rerun as rr from .constants import OBS_PREFIX, OBS_STR -def init_rerun(session_name: str = "lerobot_control_loop") -> None: - """Initializes the Rerun SDK for visualizing the control loop.""" +def init_rerun( + session_name: str = "lerobot_control_loop", ip: str | None = None, port: int | None = None +) -> None: + """ + Initializes the Rerun SDK for visualizing the control loop. + + Args: + session_name: Name of the Rerun session. + ip: Optional IP for connecting to a Rerun server. + port: Optional port for connecting to a Rerun server. 
+ """ batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size rr.init(session_name) memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%") - rr.spawn(memory_limit=memory_limit) + if ip and port: + rr.connect_grpc(url=f"rerun+http://{ip}:{port}/proxy") + else: + rr.spawn(memory_limit=memory_limit) def _is_scalar(x): @@ -40,6 +52,7 @@ def _is_scalar(x): def log_rerun_data( observation: dict[str, Any] | None = None, action: dict[str, Any] | None = None, + compress_images: bool = False, ) -> None: """ Logs observation and action data to Rerun for real-time visualization. @@ -48,7 +61,7 @@ def log_rerun_data( to the Rerun viewer. It handles different data types appropriately: - Scalars values (floats, ints) are logged as `rr.Scalars`. - 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed - from CHW to HWC format and logged as `rr.Image`. + from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`. - 1D NumPy arrays are logged as a series of individual scalars, with each element indexed. - Other multi-dimensional arrays are flattened and logged as individual scalars. @@ -57,6 +70,7 @@ def log_rerun_data( Args: observation: An optional dictionary containing observation data to log. action: An optional dictionary containing action data to log. + compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality. 
""" if observation: for k, v in observation.items(): @@ -75,7 +89,8 @@ def log_rerun_data( for i, vi in enumerate(arr): rr.log(f"{key}_{i}", rr.Scalars(float(vi))) else: - rr.log(key, rr.Image(arr), static=True) + img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr) + rr.log(key, entity=img_entity, static=True) if action: for k, v in action.items(): diff --git a/tests/utils/test_visualization_utils.py b/tests/utils/test_visualization_utils.py index f573de166..408f636cb 100644 --- a/tests/utils/test_visualization_utils.py +++ b/tests/utils/test_visualization_utils.py @@ -41,7 +41,10 @@ def mock_rerun(monkeypatch): def __init__(self, arr): self.arr = arr - def dummy_log(key, obj, **kwargs): + def dummy_log(key, obj=None, **kwargs): + # Accept either positional `obj` or keyword `entity` and record remaining kwargs. + if obj is None and "entity" in kwargs: + obj = kwargs.pop("entity") calls.append((key, obj, kwargs)) dummy_rr = SimpleNamespace( From ccfd609ece9d62746b50cbbed6146da7a60c5771 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 8 Jan 2026 13:04:30 +0100 Subject: [PATCH 024/111] feat(robots): consolidate SO arms implementation (#2763) * feat(robots): consolidate SO arms implementation * chore(robots): delete unnecessary init modules --- docs/source/async.mdx | 2 +- docs/source/envhub_leisaac.mdx | 4 +- docs/source/il_robots.mdx | 18 +- docs/source/integrate_hardware.mdx | 2 +- docs/source/lekiwi.mdx | 2 +- docs/source/so100.mdx | 8 +- docs/source/so101.mdx | 8 +- examples/backward_compatibility/replay.py | 3 +- examples/lekiwi/record.py | 2 +- examples/lekiwi/teleoperate.py | 2 +- examples/phone_to_so100/evaluate.py | 5 +- examples/phone_to_so100/record.py | 5 +- examples/phone_to_so100/replay.py | 5 +- examples/phone_to_so100/teleoperate.py | 5 +- examples/rtc/eval_with_real_robot.py | 3 +- examples/so100_to_so100_EE/evaluate.py | 5 +- examples/so100_to_so100_EE/record.py | 9 +- examples/so100_to_so100_EE/replay.py | 5 +- 
examples/so100_to_so100_EE/teleoperate.py | 9 +- examples/tutorial/act/act_using_example.py | 3 +- examples/tutorial/async-inf/robot_client.py | 2 +- .../diffusion/diffusion_using_example.py | 3 +- examples/tutorial/pi0/using_pi0_example.py | 3 +- examples/tutorial/rl/hilserl_example.py | 4 +- .../tutorial/smolvla/using_smolvla_example.py | 3 +- src/lerobot/async_inference/robot_client.py | 3 +- src/lerobot/rl/actor.py | 4 +- src/lerobot/rl/eval_policy.py | 4 +- src/lerobot/rl/gym_manipulator.py | 6 +- src/lerobot/rl/learner.py | 4 +- .../bi_so100_follower/bi_so100_follower.py | 3 +- .../so100_follower/config_so100_follower.py | 41 ---- src/lerobot/robots/so100_follower/so100.mdx | 1 - src/lerobot/robots/so101_follower/so101.mdx | 1 - .../robots/so101_follower/so101_follower.py | 230 ------------------ src/lerobot/robots/so_follower/__init__.py | 23 ++ .../robot_kinematic_processor.py | 0 .../so100_follower/config_so100_follower.py} | 14 +- .../so_follower/so100_follower/so100.md | 1 + .../so100_follower/so100_follower.py | 27 ++ .../so101_follower/config_so101_follower.py} | 14 +- .../so_follower/so101_follower/so101.md | 1 + .../so101_follower/so101_follower.py} | 13 +- .../so_follower_base.py} | 16 +- .../so_follower_config_base.py} | 5 +- src/lerobot/robots/utils.py | 4 +- src/lerobot/scripts/lerobot_calibrate.py | 6 +- .../scripts/lerobot_find_joint_limits.py | 6 +- src/lerobot/scripts/lerobot_record.py | 10 +- src/lerobot/scripts/lerobot_replay.py | 3 +- src/lerobot/scripts/lerobot_setup_motors.py | 6 +- src/lerobot/scripts/lerobot_teleoperate.py | 6 +- .../bi_so100_leader/bi_so100_leader.py | 3 +- .../so100_leader/so100_leader.py | 159 ------------ .../teleoperators/so_leader/__init__.py | 22 ++ .../so100_leader/config_so100_leader.py | 8 +- .../so_leader/so100_leader/so100.md | 1 + .../so_leader/so100_leader/so100_leader.py | 27 ++ .../so101_leader/config_so101_leader.py | 10 +- .../so_leader/so101_leader/so101.md | 1 + 
.../so_leader/so101_leader/so101_leader.py | 27 ++ .../so_leader_base.py} | 29 ++- .../so_leader/so_leader_config_base.py} | 16 +- src/lerobot/teleoperators/utils.py | 4 +- tests/robots/test_so100_follower.py | 4 +- 65 files changed, 295 insertions(+), 588 deletions(-) delete mode 100644 src/lerobot/robots/so100_follower/config_so100_follower.py delete mode 120000 src/lerobot/robots/so100_follower/so100.mdx delete mode 120000 src/lerobot/robots/so101_follower/so101.mdx delete mode 100644 src/lerobot/robots/so101_follower/so101_follower.py create mode 100644 src/lerobot/robots/so_follower/__init__.py rename src/lerobot/robots/{so100_follower => so_follower}/robot_kinematic_processor.py (100%) rename src/lerobot/{teleoperators/so101_leader/__init__.py => robots/so_follower/so100_follower/config_so100_follower.py} (64%) create mode 120000 src/lerobot/robots/so_follower/so100_follower/so100.md create mode 100644 src/lerobot/robots/so_follower/so100_follower/so100_follower.py rename src/lerobot/{teleoperators/so100_leader/__init__.py => robots/so_follower/so101_follower/config_so101_follower.py} (64%) create mode 120000 src/lerobot/robots/so_follower/so101_follower/so101.md rename src/lerobot/robots/{so101_follower/__init__.py => so_follower/so101_follower/so101_follower.py} (63%) rename src/lerobot/robots/{so100_follower/so100_follower.py => so_follower/so_follower_base.py} (93%) rename src/lerobot/robots/{so101_follower/config_so101_follower.py => so_follower/so_follower_config_base.py} (93%) delete mode 100644 src/lerobot/teleoperators/so100_leader/so100_leader.py create mode 100644 src/lerobot/teleoperators/so_leader/__init__.py rename src/lerobot/teleoperators/{ => so_leader}/so100_leader/config_so100_leader.py (83%) create mode 120000 src/lerobot/teleoperators/so_leader/so100_leader/so100.md create mode 100644 src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py rename src/lerobot/teleoperators/{ => so_leader}/so101_leader/config_so101_leader.py (81%) 
create mode 120000 src/lerobot/teleoperators/so_leader/so101_leader/so101.md create mode 100644 src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py rename src/lerobot/teleoperators/{so101_leader/so101_leader.py => so_leader/so_leader_base.py} (87%) rename src/lerobot/{robots/so100_follower/__init__.py => teleoperators/so_leader/so_leader_config_base.py} (66%) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 9dd87472c..1d3e0edbf 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -169,7 +169,7 @@ python -m lerobot.async_inference.robot_client \ ```python import threading -from lerobot.robots.so100_follower import SO100FollowerConfig +from lerobot.robots.so_follower import SO100FollowerConfig from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.async_inference.configs import RobotClientConfig from lerobot.async_inference.robot_client import RobotClient diff --git a/docs/source/envhub_leisaac.mdx b/docs/source/envhub_leisaac.mdx index 5cf4a0e45..1fc74c0fa 100644 --- a/docs/source/envhub_leisaac.mdx +++ b/docs/source/envhub_leisaac.mdx @@ -137,7 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, make_teleoperator_from_config, - so101_leader, + so_leader, ) from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import init_logging @@ -222,7 +222,7 @@ def teleoperate(cfg: TeleoperateConfig): def main(): teleoperate(TeleoperateConfig( - teleop=so101_leader.SO101LeaderConfig( + teleop=so_leader.SO101LeaderConfig( port="/dev/ttyACM0", id='leader', use_degrees=False, diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index eb779d2b1..84dc6f2f6 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -58,8 +58,8 @@ lerobot-teleoperate \ ```python -from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader -from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower 
+from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader +from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower robot_config = SO101FollowerConfig( port="/dev/tty.usbmodem58760431541", @@ -195,9 +195,9 @@ lerobot-record \ from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features -from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig -from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig -from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.teleoperators.so_leader.config_so100_leader import SO100LeaderConfig +from lerobot.teleoperators.so_leader.so100_leader import SO100Leader from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -408,8 +408,8 @@ lerobot-replay \ import time from lerobot.datasets.lerobot_dataset import LeRobotDataset -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so_follower.so100_follower import SO100Follower from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say @@ -531,8 +531,8 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.so100_follower import 
SO100Follower +from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so_follower.so100_follower import SO100Follower from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx index e1587be91..fa36e7170 100644 --- a/docs/source/integrate_hardware.mdx +++ b/docs/source/integrate_hardware.mdx @@ -18,7 +18,7 @@ If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus inter - [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/motors/motors_bus.py) abstract class to learn about its API. -For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so101_follower/so101_follower.py) +For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/so_follower/so101_follower/so101_follower.py) Use these if compatible. 
Otherwise, you'll need to find or write a Python interface (not covered in this tutorial): diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index 875394d71..511521580 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -204,7 +204,7 @@ lerobot-calibrate \ ```python -from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader +from lerobot.teleoperators.so_leader import SO100LeaderConfig, SO100Leader config = SO100LeaderConfig( port="/dev/tty.usbmodem58760431551", diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx index 3c73ae801..399781ef4 100644 --- a/docs/source/so100.mdx +++ b/docs/source/so100.mdx @@ -103,7 +103,7 @@ lerobot-setup-motors \ ```python -from lerobot.robots.so100_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig config = SO100FollowerConfig( port="/dev/tty.usbmodem585A0076841", @@ -177,7 +177,7 @@ lerobot-setup-motors \ ```python -from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig +from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig config = SO100LeaderConfig( port="/dev/tty.usbmodem585A0076841", @@ -579,7 +579,7 @@ lerobot-calibrate \ ```python -from lerobot.robots.so100_follower import SO100FollowerConfig, SO100Follower +from lerobot.robots.so_follower import SO100FollowerConfig, SO100Follower config = SO100FollowerConfig( port="/dev/tty.usbmodem585A0076891", @@ -617,7 +617,7 @@ lerobot-calibrate \ ```python -from lerobot.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader +from lerobot.teleoperators.so_leader import SO100LeaderConfig, SO100Leader config = SO100LeaderConfig( port="/dev/tty.usbmodem58760431551", diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index 57e8d691d..cf882b373 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -125,7 +125,7 @@ lerobot-setup-motors \ ```python -from lerobot.robots.so101_follower import 
SO101Follower, SO101FollowerConfig +from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig config = SO101FollowerConfig( port="/dev/tty.usbmodem585A0076841", @@ -201,7 +201,7 @@ lerobot-setup-motors \ ```python -from lerobot.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig +from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig config = SO101LeaderConfig( port="/dev/tty.usbmodem585A0076841", @@ -364,7 +364,7 @@ lerobot-calibrate \ ```python -from lerobot.robots.so101_follower import SO101FollowerConfig, SO101Follower +from lerobot.robots.so_follower import SO101FollowerConfig, SO101Follower config = SO101FollowerConfig( port="/dev/tty.usbmodem585A0076891", @@ -413,7 +413,7 @@ lerobot-calibrate \ ```python -from lerobot.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader +from lerobot.teleoperators.so_leader import SO101LeaderConfig, SO101Leader config = SO101LeaderConfig( port="/dev/tty.usbmodem58760431551", diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 85f3ecef7..ed78d016f 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -41,8 +41,7 @@ from lerobot.robots import ( # noqa: F401 RobotConfig, koch_follower, make_robot_from_config, - so100_follower, - so101_follower, + so_follower, ) from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 67d826ccb..18b9f857e 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -21,7 +21,7 @@ from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig -from lerobot.teleoperators.so100_leader import SO100Leader, 
SO100LeaderConfig +from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py index c4d20ebbe..feb3cbb01 100644 --- a/examples/lekiwi/teleoperate.py +++ b/examples/lekiwi/teleoperate.py @@ -18,7 +18,7 @@ import time from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig -from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig +from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 5a47b8ffa..246c923aa 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -34,12 +34,11 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index e563d8eb3..7b5b704e2 100644 --- a/examples/phone_to_so100/record.py +++ 
b/examples/phone_to_so100/record.py @@ -26,15 +26,14 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( EEBoundsAndSafety, EEReferenceAndDelta, ForwardKinematicsJointsToEE, GripperVelocityToJoint, InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index ab6ce3ced..875025dfc 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -23,11 +23,10 @@ from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/phone_to_so100/teleoperate.py b/examples/phone_to_so100/teleoperate.py index 2ac8b3cce..6eaaec806 100644 --- a/examples/phone_to_so100/teleoperate.py +++ b/examples/phone_to_so100/teleoperate.py @@ -21,14 +21,13 @@ from 
lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( EEBoundsAndSafety, EEReferenceAndDelta, GripperVelocityToJoint, InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction from lerobot.teleoperators.phone.teleop_phone import Phone diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 5f44649da..991d2468d 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -95,8 +95,7 @@ from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, koch_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.robots.utils import make_robot_from_config from lerobot.utils.constants import OBS_IMAGES diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index 90973d373..87d188f99 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -34,12 +34,11 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower 
import SO100Follower from lerobot.scripts.lerobot_record import record_loop from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index 6bfdfe32d..db24f4b93 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -27,16 +27,15 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( EEBoundsAndSafety, ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.scripts.lerobot_record import record_loop -from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig -from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.teleoperators.so_leader import SO100LeaderConfig +from lerobot.teleoperators.so_leader.so100_leader import SO100Leader from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 59c524078..7d35a7b44 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -24,11 +24,10 @@ from lerobot.processor.converters import ( robot_action_observation_to_transition, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower 
import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower import SO100Follower from lerobot.utils.constants import ACTION from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py index 21299103b..d520a6eaf 100644 --- a/examples/so100_to_so100_EE/teleoperate.py +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -23,15 +23,14 @@ from lerobot.processor.converters import ( robot_action_to_transition, transition_to_robot_action, ) -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( EEBoundsAndSafety, ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.robots.so100_follower.so100_follower import SO100Follower -from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig -from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.teleoperators.so_leader import SO100LeaderConfig +from lerobot.teleoperators.so_leader.so100_leader import SO100Leader from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/examples/tutorial/act/act_using_example.py b/examples/tutorial/act/act_using_example.py index b268e8790..60bc802d8 100644 --- a/examples/tutorial/act/act_using_example.py +++ b/examples/tutorial/act/act_using_example.py @@ -5,8 +5,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.act.modeling_act import ACTPolicy from lerobot.policies.factory import make_pre_post_processors 
from lerobot.policies.utils import build_inference_frame, make_robot_action -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig MAX_EPISODES = 5 MAX_STEPS_PER_EPISODE = 20 diff --git a/examples/tutorial/async-inf/robot_client.py b/examples/tutorial/async-inf/robot_client.py index fff7b15b3..eb3751169 100644 --- a/examples/tutorial/async-inf/robot_client.py +++ b/examples/tutorial/async-inf/robot_client.py @@ -4,7 +4,7 @@ from lerobot.async_inference.configs import RobotClientConfig from lerobot.async_inference.helpers import visualize_action_queue_size from lerobot.async_inference.robot_client import RobotClient from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig -from lerobot.robots.so100_follower import SO100FollowerConfig +from lerobot.robots.so_follower import SO100FollowerConfig def main(): diff --git a/examples/tutorial/diffusion/diffusion_using_example.py b/examples/tutorial/diffusion/diffusion_using_example.py index 96cc607b6..d8ac75cfe 100644 --- a/examples/tutorial/diffusion/diffusion_using_example.py +++ b/examples/tutorial/diffusion/diffusion_using_example.py @@ -5,8 +5,7 @@ from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.utils import build_inference_frame, make_robot_action -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig MAX_EPISODES = 5 MAX_STEPS_PER_EPISODE = 20 diff --git a/examples/tutorial/pi0/using_pi0_example.py b/examples/tutorial/pi0/using_pi0_example.py index 362092ccf..056c3d81a 100644 
--- a/examples/tutorial/pi0/using_pi0_example.py +++ b/examples/tutorial/pi0/using_pi0_example.py @@ -5,8 +5,7 @@ from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.pi0.modeling_pi0 import PI0Policy from lerobot.policies.utils import build_inference_frame, make_robot_action -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig MAX_EPISODES = 5 MAX_STEPS_PER_EPISODE = 20 diff --git a/examples/tutorial/rl/hilserl_example.py b/examples/tutorial/rl/hilserl_example.py index c49233ebb..980ac7985 100644 --- a/examples/tutorial/rl/hilserl_example.py +++ b/examples/tutorial/rl/hilserl_example.py @@ -14,8 +14,8 @@ from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.policies.sac.reward_model.modeling_classifier import Classifier from lerobot.rl.buffer import ReplayBuffer from lerobot.rl.gym_manipulator import make_robot_env -from lerobot.robots.so100_follower import SO100FollowerConfig -from lerobot.teleoperators.so100_leader import SO100LeaderConfig +from lerobot.robots.so_follower import SO100FollowerConfig +from lerobot.teleoperators.so_leader import SO100LeaderConfig from lerobot.teleoperators.utils import TeleopEvents LOG_EVERY = 10 diff --git a/examples/tutorial/smolvla/using_smolvla_example.py b/examples/tutorial/smolvla/using_smolvla_example.py index d4219f316..ce3aa7bca 100644 --- a/examples/tutorial/smolvla/using_smolvla_example.py +++ b/examples/tutorial/smolvla/using_smolvla_example.py @@ -5,8 +5,7 @@ from lerobot.datasets.utils import hw_to_dataset_features from lerobot.policies.factory import make_pre_post_processors from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy from lerobot.policies.utils import build_inference_frame, make_robot_action -from 
lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig -from lerobot.robots.so100_follower.so100_follower import SO100Follower +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig MAX_EPISODES = 5 MAX_STEPS_PER_EPISODE = 20 diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index d32aa6a21..c3668d40b 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -55,8 +55,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.transport import ( services_pb2, # type: ignore diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 641a21d03..7427633d2 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -65,8 +65,8 @@ from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.processor import TransitionKey from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.queue import get_last_item_from_queue -from lerobot.robots import so100_follower # noqa: F401 -from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.robots import so_follower # noqa: F401 +from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2, services_pb2_grpc from lerobot.transport.utils import ( diff --git a/src/lerobot/rl/eval_policy.py b/src/lerobot/rl/eval_policy.py index 16bb64a73..fb2504f2a 100644 --- a/src/lerobot/rl/eval_policy.py +++ b/src/lerobot/rl/eval_policy.py @@ -23,11 +23,11 @@ from lerobot.policies.factory import make_policy from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, - so100_follower, + so_follower, ) from lerobot.teleoperators import ( gamepad, # noqa: F401 - so101_leader, # noqa: F401 + so_leader, # noqa: F401 ) from 
.gym_manipulator import make_robot_env diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 7bc74a959..604adb931 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -55,10 +55,10 @@ from lerobot.processor.converters import identity_transition from lerobot.robots import ( # noqa: F401 RobotConfig, make_robot_from_config, - so100_follower, + so_follower, ) from lerobot.robots.robot import Robot -from lerobot.robots.so100_follower.robot_kinematic_processor import ( +from lerobot.robots.so_follower.robot_kinematic_processor import ( EEBoundsAndSafety, EEReferenceAndDelta, ForwardKinematicsJointsToEEObservation, @@ -69,7 +69,7 @@ from lerobot.teleoperators import ( gamepad, # noqa: F401 keyboard, # noqa: F401 make_teleoperator_from_config, - so101_leader, # noqa: F401 + so_leader, # noqa: F401 ) from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index d9758d3a3..abc5c9504 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -69,8 +69,8 @@ from lerobot.policies.sac.modeling_sac import SACPolicy from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions from lerobot.rl.process import ProcessSignalHandler from lerobot.rl.wandb_utils import WandBLogger -from lerobot.robots import so100_follower # noqa: F401 -from lerobot.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.robots import so_follower # noqa: F401 +from lerobot.teleoperators import gamepad, so_leader # noqa: F401 from lerobot.teleoperators.utils import TeleopEvents from lerobot.transport import services_pb2_grpc from lerobot.transport.utils import ( diff --git a/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py b/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py index 7992b79fd..87a7edcc5 100644 --- 
a/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py +++ b/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py @@ -20,8 +20,7 @@ from functools import cached_property from typing import Any from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.robots.so100_follower import SO100Follower -from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from ..robot import Robot from .config_bi_so100_follower import BiSO100FollowerConfig diff --git a/src/lerobot/robots/so100_follower/config_so100_follower.py b/src/lerobot/robots/so100_follower/config_so100_follower.py deleted file mode 100644 index 272b8c43f..000000000 --- a/src/lerobot/robots/so100_follower/config_so100_follower.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - -from lerobot.cameras import CameraConfig - -from ..config import RobotConfig - - -@RobotConfig.register_subclass("so100_follower") -@dataclass -class SO100FollowerConfig(RobotConfig): - # Port to connect to the arm - port: str - - disable_torque_on_disconnect: bool = True - - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. 
- # Set this to a positive scalar to have the same value for all motors, or a dictionary that maps motor - # names to the max_relative_target value for that motor. - max_relative_target: float | dict[str, float] | None = None - - # cameras - cameras: dict[str, CameraConfig] = field(default_factory=dict) - - # Set to `True` for backward compatibility with previous policies/dataset - use_degrees: bool = False diff --git a/src/lerobot/robots/so100_follower/so100.mdx b/src/lerobot/robots/so100_follower/so100.mdx deleted file mode 120000 index ad1154e75..000000000 --- a/src/lerobot/robots/so100_follower/so100.mdx +++ /dev/null @@ -1 +0,0 @@ -../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so101_follower/so101.mdx b/src/lerobot/robots/so101_follower/so101.mdx deleted file mode 120000 index 27b892660..000000000 --- a/src/lerobot/robots/so101_follower/so101.mdx +++ /dev/null @@ -1 +0,0 @@ -../../../../docs/source/so101.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so101_follower/so101_follower.py b/src/lerobot/robots/so101_follower/so101_follower.py deleted file mode 100644 index acfd4bd11..000000000 --- a/src/lerobot/robots/so101_follower/so101_follower.py +++ /dev/null @@ -1,230 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import time -from functools import cached_property -from typing import Any - -from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.motors import Motor, MotorCalibration, MotorNormMode -from lerobot.motors.feetech import ( - FeetechMotorsBus, - OperatingMode, -) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError - -from ..robot import Robot -from ..utils import ensure_safe_goal_position -from .config_so101_follower import SO101FollowerConfig - -logger = logging.getLogger(__name__) - - -class SO101Follower(Robot): - """ - SO-101 Follower Arm designed by TheRobotStudio and Hugging Face. - """ - - config_class = SO101FollowerConfig - name = "so101_follower" - - def __init__(self, config: SO101FollowerConfig): - super().__init__(config) - self.config = config - norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 - self.bus = FeetechMotorsBus( - port=self.config.port, - motors={ - "shoulder_pan": Motor(1, "sts3215", norm_mode_body), - "shoulder_lift": Motor(2, "sts3215", norm_mode_body), - "elbow_flex": Motor(3, "sts3215", norm_mode_body), - "wrist_flex": Motor(4, "sts3215", norm_mode_body), - "wrist_roll": Motor(5, "sts3215", norm_mode_body), - "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), - }, - calibration=self.calibration, - ) - self.cameras = make_cameras_from_configs(config.cameras) - - @property - def _motors_ft(self) -> dict[str, type]: - return {f"{motor}.pos": float for motor in self.bus.motors} - - @property - def _cameras_ft(self) -> dict[str, tuple]: - return { - cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras - } - - @cached_property - def observation_features(self) -> dict[str, type | tuple]: - return {**self._motors_ft, **self._cameras_ft} - - @cached_property - def action_features(self) -> dict[str, type]: - return self._motors_ft - - @property - def is_connected(self) -> 
bool: - return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) - - def connect(self, calibrate: bool = True) -> None: - """ - We assume that at connection time, arm is in a rest position, - and torque can be safely disabled to run calibration. - """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - - self.bus.connect() - if not self.is_calibrated and calibrate: - logger.info( - "Mismatch between calibration values in the motor and the calibration file or no calibration file found" - ) - self.calibrate() - - for cam in self.cameras.values(): - cam.connect() - - self.configure() - logger.info(f"{self} connected.") - - @property - def is_calibrated(self) -> bool: - return self.bus.is_calibrated - - def calibrate(self) -> None: - if self.calibration: - # self.calibration is not empty here - user_input = input( - f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " - ) - if user_input.strip().lower() != "c": - logger.info(f"Writing calibration file associated with the id {self.id} to the motors") - self.bus.write_calibration(self.calibration) - return - - logger.info(f"\nRunning calibration of {self}") - self.bus.disable_torque() - for motor in self.bus.motors: - self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) - - input(f"Move {self} to the middle of its range of motion and press ENTER....") - homing_offsets = self.bus.set_half_turn_homings() - - print( - "Move all joints sequentially through their entire ranges " - "of motion.\nRecording positions. Press ENTER to stop..." 
- ) - range_mins, range_maxes = self.bus.record_ranges_of_motion() - - self.calibration = {} - for motor, m in self.bus.motors.items(): - self.calibration[motor] = MotorCalibration( - id=m.id, - drive_mode=0, - homing_offset=homing_offsets[motor], - range_min=range_mins[motor], - range_max=range_maxes[motor], - ) - - self.bus.write_calibration(self.calibration) - self._save_calibration() - print("Calibration saved to", self.calibration_fpath) - - def configure(self) -> None: - with self.bus.torque_disabled(): - self.bus.configure_motors() - for motor in self.bus.motors: - self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) - # Set P_Coefficient to lower value to avoid shakiness (Default is 32) - self.bus.write("P_Coefficient", motor, 16) - # Set I_Coefficient and D_Coefficient to default value 0 and 32 - self.bus.write("I_Coefficient", motor, 0) - self.bus.write("D_Coefficient", motor, 32) - - if motor == "gripper": - self.bus.write( - "Max_Torque_Limit", motor, 500 - ) # 50% of the max torque limit to avoid burnout - self.bus.write("Protection_Current", motor, 250) # 50% of max current to avoid burnout - self.bus.write("Overload_Torque", motor, 25) # 25% torque when overloaded - - def setup_motors(self) -> None: - for motor in reversed(self.bus.motors): - input(f"Connect the controller board to the '{motor}' motor only and press enter.") - self.bus.setup_motor(motor) - print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - - def get_observation(self) -> dict[str, Any]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - # Read arm position - start = time.perf_counter() - obs_dict = self.bus.sync_read("Present_Position") - obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read state: {dt_ms:.1f}ms") - - # Capture images from cameras - for cam_key, cam in self.cameras.items(): - start = time.perf_counter() - 
obs_dict[cam_key] = cam.async_read() - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") - - return obs_dict - - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: - """Command arm to move to a target joint configuration. - - The relative action magnitude may be clipped depending on the configuration parameter - `max_relative_target`. In this case, the action sent differs from original action. - Thus, this function always returns the action actually sent. - - Raises: - RobotDeviceNotConnectedError: if robot is not connected. - - Returns: - the action sent to the motors, potentially clipped. - """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} - - # Cap goal position when too far away from present position. - # /!\ Slower fps expected due to reading from the follower. - if self.config.max_relative_target is not None: - present_pos = self.bus.sync_read("Present_Position") - goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} - goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) - - # Send goal position to the arm - self.bus.sync_write("Goal_Position", goal_pos) - return {f"{motor}.pos": val for motor, val in goal_pos.items()} - - def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - - self.bus.disconnect(self.config.disable_torque_on_disconnect) - for cam in self.cameras.values(): - cam.disconnect() - - logger.info(f"{self} disconnected.") diff --git a/src/lerobot/robots/so_follower/__init__.py b/src/lerobot/robots/so_follower/__init__.py new file mode 100644 index 000000000..82755250c --- /dev/null +++ b/src/lerobot/robots/so_follower/__init__.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from .so100_follower.config_so100_follower import SO100FollowerConfig +from .so100_follower.so100_follower import SO100Follower +from .so101_follower.config_so101_follower import SO101FollowerConfig +from .so101_follower.so101_follower import SO101Follower +from .so_follower_base import SOFollowerBase +from .so_follower_config_base import SOFollowerConfigBase diff --git a/src/lerobot/robots/so100_follower/robot_kinematic_processor.py b/src/lerobot/robots/so_follower/robot_kinematic_processor.py similarity index 100% rename from src/lerobot/robots/so100_follower/robot_kinematic_processor.py rename to src/lerobot/robots/so_follower/robot_kinematic_processor.py diff --git a/src/lerobot/teleoperators/so101_leader/__init__.py b/src/lerobot/robots/so_follower/so100_follower/config_so100_follower.py similarity index 64% rename from src/lerobot/teleoperators/so101_leader/__init__.py rename to src/lerobot/robots/so_follower/so100_follower/config_so100_follower.py index 11e277c91..3e8dd7e9f 100644 --- a/src/lerobot/teleoperators/so101_leader/__init__.py +++ b/src/lerobot/robots/so_follower/so100_follower/config_so100_follower.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,5 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_so101_leader import SO101LeaderConfig -from .so101_leader import SO101Leader +from dataclasses import dataclass + +from ...config import RobotConfig +from ..so_follower_config_base import SOFollowerConfigBase + + +@RobotConfig.register_subclass("so100_follower") +@dataclass +class SO100FollowerConfig(SOFollowerConfigBase): + pass diff --git a/src/lerobot/robots/so_follower/so100_follower/so100.md b/src/lerobot/robots/so_follower/so100_follower/so100.md new file mode 120000 index 000000000..f06f88ff6 --- /dev/null +++ b/src/lerobot/robots/so_follower/so100_follower/so100.md @@ -0,0 +1 @@ +../../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so_follower/so100_follower/so100_follower.py b/src/lerobot/robots/so_follower/so100_follower/so100_follower.py new file mode 100644 index 000000000..ef61f5ce3 --- /dev/null +++ b/src/lerobot/robots/so_follower/so100_follower/so100_follower.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ..so_follower_base import SOFollowerBase +from .config_so100_follower import SO100FollowerConfig + + +class SO100Follower(SOFollowerBase): + """ + SO-100 follower robot class. [SO-100 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + """ + + config_class = SO100FollowerConfig + name = "so100_follower" diff --git a/src/lerobot/teleoperators/so100_leader/__init__.py b/src/lerobot/robots/so_follower/so101_follower/config_so101_follower.py similarity index 64% rename from src/lerobot/teleoperators/so100_leader/__init__.py rename to src/lerobot/robots/so_follower/so101_follower/config_so101_follower.py index 747416be2..950e8c839 100644 --- a/src/lerobot/teleoperators/so100_leader/__init__.py +++ b/src/lerobot/robots/so_follower/so101_follower/config_so101_follower.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,5 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License.
-from .config_so100_leader import SO100LeaderConfig -from .so100_leader import SO100Leader +from dataclasses import dataclass + +from ...config import RobotConfig +from ..so_follower_config_base import SOFollowerConfigBase + + +@RobotConfig.register_subclass("so101_follower") +@dataclass +class SO101FollowerConfig(SOFollowerConfigBase): + pass diff --git a/src/lerobot/robots/so_follower/so101_follower/so101.md b/src/lerobot/robots/so_follower/so101_follower/so101.md new file mode 120000 index 000000000..38f4deca7 --- /dev/null +++ b/src/lerobot/robots/so_follower/so101_follower/so101.md @@ -0,0 +1 @@ +../../../../../docs/source/so101.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so101_follower/__init__.py b/src/lerobot/robots/so_follower/so101_follower/so101_follower.py similarity index 63% rename from src/lerobot/robots/so101_follower/__init__.py rename to src/lerobot/robots/so_follower/so101_follower/so101_follower.py index 9ff2baf45..b4cfb2711 100644 --- a/src/lerobot/robots/so101_follower/__init__.py +++ b/src/lerobot/robots/so_follower/so101_follower/so101_follower.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,5 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from ..so_follower_base import SOFollowerBase from .config_so101_follower import SO101FollowerConfig -from .so101_follower import SO101Follower + + +class SO101Follower(SOFollowerBase): + """ + SO-101 follower robot class. 
[SO-101 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + """ + + config_class = SO101FollowerConfig + name = "so101_follower" diff --git a/src/lerobot/robots/so100_follower/so100_follower.py b/src/lerobot/robots/so_follower/so_follower_base.py similarity index 93% rename from src/lerobot/robots/so100_follower/so100_follower.py rename to src/lerobot/robots/so_follower/so_follower_base.py index d660ebed4..5e6e32888 100644 --- a/src/lerobot/robots/so100_follower/so100_follower.py +++ b/src/lerobot/robots/so_follower/so_follower_base.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -29,22 +29,23 @@ from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnected from ..robot import Robot from ..utils import ensure_safe_goal_position -from .config_so100_follower import SO100FollowerConfig +from .so_follower_config_base import SOFollowerConfigBase logger = logging.getLogger(__name__) -class SO100Follower(Robot): +class SOFollowerBase(Robot): """ - [SO-100 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + Generic SO follower base implementing common functionality for SO-100/101/10X. + Designed to be subclassed with a per-hardware-model `config_class` and `name`. 
""" - config_class = SO100FollowerConfig - name = "so100_follower" + # `config_class` and `name` should be set by subclasses - def __init__(self, config: SO100FollowerConfig): + def __init__(self, config: SOFollowerConfigBase): super().__init__(config) self.config = config + # choose normalization mode depending on config if available norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 self.bus = FeetechMotorsBus( port=self.config.port, @@ -126,6 +127,7 @@ class SO100Follower(Robot): input(f"Move {self} to the middle of its range of motion and press ENTER....") homing_offsets = self.bus.set_half_turn_homings() + # Attempt to call record_ranges_of_motion with a reduced motor set when appropriate. full_turn_motor = "wrist_roll" unknown_range_motors = [motor for motor in self.bus.motors if motor != full_turn_motor] print( diff --git a/src/lerobot/robots/so101_follower/config_so101_follower.py b/src/lerobot/robots/so_follower/so_follower_config_base.py similarity index 93% rename from src/lerobot/robots/so101_follower/config_so101_follower.py rename to src/lerobot/robots/so_follower/so_follower_config_base.py index 03c3530c2..6e42df278 100644 --- a/src/lerobot/robots/so101_follower/config_so101_follower.py +++ b/src/lerobot/robots/so_follower/so_follower_config_base.py @@ -21,9 +21,10 @@ from lerobot.cameras import CameraConfig from ..config import RobotConfig -@RobotConfig.register_subclass("so101_follower") @dataclass -class SO101FollowerConfig(RobotConfig): +class SOFollowerConfigBase(RobotConfig): + """Base configuration class for SO Follower robots.""" + # Port to connect to the arm port: str diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index 9c5043335..ad6cc3da1 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -33,11 +33,11 @@ def make_robot_from_config(config: RobotConfig) -> Robot: return OmxFollower(config) elif config.type == "so100_follower": - from 
.so100_follower import SO100Follower + from .so_follower import SO100Follower return SO100Follower(config) elif config.type == "so101_follower": - from .so101_follower import SO101Follower + from .so_follower import SO101Follower return SO101Follower(config) elif config.type == "lekiwi": diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 910a9a1b5..1468c6e93 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -41,8 +41,7 @@ from lerobot.robots import ( # noqa: F401 lekiwi, make_robot_from_config, omx_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -51,8 +50,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, - so100_leader, - so101_leader, + so_leader, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.utils import init_logging diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index f97c0d820..b36cdbc90 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -47,8 +47,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, @@ -56,8 +55,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, - so100_leader, - so101_leader, + so_leader, ) from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index ced6b307f..5d2945e67 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -98,8 +98,7 @@ from lerobot.robots import ( # 
noqa: F401 koch_follower, make_robot_from_config, omx_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -109,8 +108,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, - so100_leader, - so101_leader, + so_leader, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop from lerobot.utils.constants import ACTION, OBS_STR @@ -282,8 +280,8 @@ def record_loop( if isinstance( t, ( - so100_leader.SO100Leader - | so101_leader.SO101Leader + so_leader.SO100Leader + | so_leader.SO101Leader | koch_leader.KochLeader | omx_leader.OmxLeader ), diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 5cde4251c..af7c63365 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -59,8 +59,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.utils.constants import ACTION from lerobot.utils.import_utils import register_third_party_plugins diff --git a/src/lerobot/scripts/lerobot_setup_motors.py b/src/lerobot/scripts/lerobot_setup_motors.py index b721e55ca..ea5c821d1 100644 --- a/src/lerobot/scripts/lerobot_setup_motors.py +++ b/src/lerobot/scripts/lerobot_setup_motors.py @@ -34,16 +34,14 @@ from lerobot.robots import ( # noqa: F401 lekiwi, make_robot_from_config, omx_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, koch_leader, make_teleoperator_from_config, omx_leader, - so100_leader, - so101_leader, + so_leader, ) COMPATIBLE_DEVICES = [ diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 95025891e..2e0724574 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ 
b/src/lerobot/scripts/lerobot_teleoperate.py @@ -76,8 +76,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, - so100_follower, - so101_follower, + so_follower, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -89,8 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, - so100_leader, - so101_leader, + so_leader, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py b/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py index 769669655..93f66eb2e 100644 --- a/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py +++ b/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py @@ -17,8 +17,7 @@ import logging from functools import cached_property -from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig -from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from ..teleoperator import Teleoperator from .config_bi_so100_leader import BiSO100LeaderConfig diff --git a/src/lerobot/teleoperators/so100_leader/so100_leader.py b/src/lerobot/teleoperators/so100_leader/so100_leader.py deleted file mode 100644 index edcfe53e6..000000000 --- a/src/lerobot/teleoperators/so100_leader/so100_leader.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time - -from lerobot.motors import Motor, MotorCalibration, MotorNormMode -from lerobot.motors.feetech import ( - FeetechMotorsBus, - OperatingMode, -) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError - -from ..teleoperator import Teleoperator -from .config_so100_leader import SO100LeaderConfig - -logger = logging.getLogger(__name__) - - -class SO100Leader(Teleoperator): - """ - [SO-100 Leader Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio - """ - - config_class = SO100LeaderConfig - name = "so100_leader" - - def __init__(self, config: SO100LeaderConfig): - super().__init__(config) - self.config = config - self.bus = FeetechMotorsBus( - port=self.config.port, - motors={ - "shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100), - "shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), - "elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), - "wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100), - "wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), - "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), - }, - calibration=self.calibration, - ) - - @property - def action_features(self) -> dict[str, type]: - return {f"{motor}.pos": float for motor in self.bus.motors} - - @property - def feedback_features(self) -> dict[str, type]: - return {} - - @property - def is_connected(self) -> bool: - return self.bus.is_connected - - def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise 
DeviceAlreadyConnectedError(f"{self} already connected") - - self.bus.connect() - if not self.is_calibrated and calibrate: - logger.info( - "Mismatch between calibration values in the motor and the calibration file or no calibration file found" - ) - self.calibrate() - - self.configure() - logger.info(f"{self} connected.") - - @property - def is_calibrated(self) -> bool: - return self.bus.is_calibrated - - def calibrate(self) -> None: - if self.calibration: - # Calibration file exists, ask user whether to use it or run new calibration - user_input = input( - f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: " - ) - if user_input.strip().lower() != "c": - logger.info(f"Writing calibration file associated with the id {self.id} to the motors") - self.bus.write_calibration(self.calibration) - return - - logger.info(f"\nRunning calibration of {self}") - self.bus.disable_torque() - for motor in self.bus.motors: - self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) - - input(f"Move {self} to the middle of its range of motion and press ENTER....") - homing_offsets = self.bus.set_half_turn_homings() - - full_turn_motor = "wrist_roll" - unknown_range_motors = [motor for motor in self.bus.motors if motor != full_turn_motor] - print( - f"Move all joints except '{full_turn_motor}' sequentially through their " - "entire ranges of motion.\nRecording positions. Press ENTER to stop..." 
- ) - range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) - range_mins[full_turn_motor] = 0 - range_maxes[full_turn_motor] = 4095 - - self.calibration = {} - for motor, m in self.bus.motors.items(): - self.calibration[motor] = MotorCalibration( - id=m.id, - drive_mode=0, - homing_offset=homing_offsets[motor], - range_min=range_mins[motor], - range_max=range_maxes[motor], - ) - - self.bus.write_calibration(self.calibration) - self._save_calibration() - print(f"Calibration saved to {self.calibration_fpath}") - - def configure(self) -> None: - self.bus.disable_torque() - self.bus.configure_motors() - for motor in self.bus.motors: - self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) - - def setup_motors(self) -> None: - for motor in reversed(self.bus.motors): - input(f"Connect the controller board to the '{motor}' motor only and press enter.") - self.bus.setup_motor(motor) - print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - - def get_action(self) -> dict[str, float]: - start = time.perf_counter() - action = self.bus.sync_read("Present_Position") - action = {f"{motor}.pos": val for motor, val in action.items()} - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read action: {dt_ms:.1f}ms") - return action - - def send_feedback(self, feedback: dict[str, float]) -> None: - # TODO(rcadene, aliberts): Implement force feedback - raise NotImplementedError - - def disconnect(self) -> None: - if not self.is_connected: - DeviceNotConnectedError(f"{self} is not connected.") - - self.bus.disconnect() - logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/so_leader/__init__.py b/src/lerobot/teleoperators/so_leader/__init__.py new file mode 100644 index 000000000..a1017f3b9 --- /dev/null +++ b/src/lerobot/teleoperators/so_leader/__init__.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .so100_leader.config_so100_leader import SO100LeaderConfig +from .so100_leader.so100_leader import SO100Leader +from .so101_leader.config_so101_leader import SO101LeaderConfig +from .so101_leader.so101_leader import SO101Leader +from .so_leader_base import SOLeaderBase +from .so_leader_config_base import SOLeaderConfigBase diff --git a/src/lerobot/teleoperators/so100_leader/config_so100_leader.py b/src/lerobot/teleoperators/so_leader/so100_leader/config_so100_leader.py similarity index 83% rename from src/lerobot/teleoperators/so100_leader/config_so100_leader.py rename to src/lerobot/teleoperators/so_leader/so100_leader/config_so100_leader.py index a97949b7e..092eb0cfc 100644 --- a/src/lerobot/teleoperators/so100_leader/config_so100_leader.py +++ b/src/lerobot/teleoperators/so_leader/so100_leader/config_so100_leader.py @@ -16,11 +16,11 @@ from dataclasses import dataclass -from ..config import TeleoperatorConfig +from ...config import TeleoperatorConfig +from ..so_leader_config_base import SOLeaderConfigBase @TeleoperatorConfig.register_subclass("so100_leader") @dataclass -class SO100LeaderConfig(TeleoperatorConfig): - # Port to connect to the arm - port: str +class SO100LeaderConfig(SOLeaderConfigBase): + pass diff --git a/src/lerobot/teleoperators/so_leader/so100_leader/so100.md b/src/lerobot/teleoperators/so_leader/so100_leader/so100.md new file mode 120000 index 000000000..f06f88ff6 --- /dev/null +++ 
b/src/lerobot/teleoperators/so_leader/so100_leader/so100.md @@ -0,0 +1 @@ +../../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py b/src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py new file mode 100644 index 000000000..530e4f723 --- /dev/null +++ b/src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py @@ -0,0 +1,27 @@ +# !/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..so_leader_base import SOLeaderBase +from .config_so100_leader import SO100LeaderConfig + + +class SO100Leader(SOLeaderBase): + """ + SO-100 leader robot class.
[SO-100 Leader Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + """ + + config_class = SO100LeaderConfig + name = "so100_leader" diff --git a/src/lerobot/teleoperators/so101_leader/config_so101_leader.py b/src/lerobot/teleoperators/so_leader/so101_leader/config_so101_leader.py similarity index 81% rename from src/lerobot/teleoperators/so101_leader/config_so101_leader.py rename to src/lerobot/teleoperators/so_leader/so101_leader/config_so101_leader.py index 8d91c32df..1a6bb0df4 100644 --- a/src/lerobot/teleoperators/so101_leader/config_so101_leader.py +++ b/src/lerobot/teleoperators/so_leader/so101_leader/config_so101_leader.py @@ -16,13 +16,11 @@ from dataclasses import dataclass -from ..config import TeleoperatorConfig +from ...config import TeleoperatorConfig +from ..so_leader_config_base import SOLeaderConfigBase @TeleoperatorConfig.register_subclass("so101_leader") @dataclass -class SO101LeaderConfig(TeleoperatorConfig): - # Port to connect to the arm - port: str - - use_degrees: bool = False +class SO101LeaderConfig(SOLeaderConfigBase): + pass diff --git a/src/lerobot/teleoperators/so_leader/so101_leader/so101.md b/src/lerobot/teleoperators/so_leader/so101_leader/so101.md new file mode 120000 index 000000000..38f4deca7 --- /dev/null +++ b/src/lerobot/teleoperators/so_leader/so101_leader/so101.md @@ -0,0 +1 @@ +../../../../../docs/source/so101.mdx \ No newline at end of file diff --git a/src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py b/src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py new file mode 100644 index 000000000..3aed1f6f9 --- /dev/null +++ b/src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py @@ -0,0 +1,27 @@ +# !/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..so_leader_base import SOLeaderBase +from .config_so101_leader import SO101LeaderConfig + + +class SO101Leader(SOLeaderBase): + """ + SO-101 leader robot class. [SO-101 Leader Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + """ + + config_class = SO101LeaderConfig + name = "so101_leader" diff --git a/src/lerobot/teleoperators/so101_leader/so101_leader.py b/src/lerobot/teleoperators/so_leader/so_leader_base.py similarity index 87% rename from src/lerobot/teleoperators/so101_leader/so101_leader.py rename to src/lerobot/teleoperators/so_leader/so_leader_base.py index be804bf70..6cdaca2e7 100644 --- a/src/lerobot/teleoperators/so101_leader/so101_leader.py +++ b/src/lerobot/teleoperators/so_leader/so_leader_base.py @@ -1,6 +1,6 @@ -#!/usr/bin/env python +# !/usr/bin/env python -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,20 +25,15 @@ from lerobot.motors.feetech import ( from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator -from .config_so101_leader import SO101LeaderConfig +from .so_leader_config_base import SOLeaderConfigBase logger = logging.getLogger(__name__) -class SO101Leader(Teleoperator): - """ - SO-101 Leader Arm designed by TheRobotStudio and Hugging Face. 
- """ +class SOLeaderBase(Teleoperator): + """Generic SO leader base for SO-100/101/10X teleoperators.""" - config_class = SO101LeaderConfig - name = "so101_leader" - - def __init__(self, config: SO101LeaderConfig): + def __init__(self, config: SOLeaderConfigBase): super().__init__(config) self.config = config norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 @@ -104,11 +99,15 @@ class SO101Leader(Teleoperator): input(f"Move {self} to the middle of its range of motion and press ENTER....") homing_offsets = self.bus.set_half_turn_homings() + full_turn_motor = "wrist_roll" + unknown_range_motors = [motor for motor in self.bus.motors if motor != full_turn_motor] print( - "Move all joints sequentially through their entire ranges " - "of motion.\nRecording positions. Press ENTER to stop..." + f"Move all joints except '{full_turn_motor}' sequentially through their " + "entire ranges of motion.\nRecording positions. Press ENTER to stop..." ) - range_mins, range_maxes = self.bus.record_ranges_of_motion() + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + range_mins[full_turn_motor] = 0 + range_maxes[full_turn_motor] = 4095 self.calibration = {} for motor, m in self.bus.motors.items(): @@ -145,7 +144,7 @@ class SO101Leader(Teleoperator): return action def send_feedback(self, feedback: dict[str, float]) -> None: - # TODO(rcadene, aliberts): Implement force feedback + # TODO: Implement force feedback raise NotImplementedError def disconnect(self) -> None: diff --git a/src/lerobot/robots/so100_follower/__init__.py b/src/lerobot/teleoperators/so_leader/so_leader_config_base.py similarity index 66% rename from src/lerobot/robots/so100_follower/__init__.py rename to src/lerobot/teleoperators/so_leader/so_leader_config_base.py index 5dc43ac3b..b2f2dfcfa 100644 --- a/src/lerobot/robots/so100_follower/__init__.py +++ b/src/lerobot/teleoperators/so_leader/so_leader_config_base.py @@ -14,5 +14,17 @@ # See 
the License for the specific language governing permissions and # limitations under the License. -from .config_so100_follower import SO100FollowerConfig -from .so100_follower import SO100Follower +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@dataclass +class SOLeaderConfigBase(TeleoperatorConfig): + """Base configuration class for SO Leader teleoperators.""" + + # Port to connect to the arm + port: str + + # Whether to use degrees for angles + use_degrees: bool = False diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 699d1253f..74e43ec95 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -46,11 +46,11 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: return OmxLeader(config) elif config.type == "so100_leader": - from .so100_leader import SO100Leader + from .so_leader import SO100Leader return SO100Leader(config) elif config.type == "so101_leader": - from .so101_leader import SO101Leader + from .so_leader import SO101Leader return SO101Leader(config) elif config.type == "mock_teleop": diff --git a/tests/robots/test_so100_follower.py b/tests/robots/test_so100_follower.py index d76b9591a..fc300b820 100644 --- a/tests/robots/test_so100_follower.py +++ b/tests/robots/test_so100_follower.py @@ -19,7 +19,7 @@ from unittest.mock import MagicMock, patch import pytest -from lerobot.robots.so100_follower import ( +from lerobot.robots.so_follower import ( SO100Follower, SO100FollowerConfig, ) @@ -66,7 +66,7 @@ def follower(): with ( patch( - "lerobot.robots.so100_follower.so100_follower.FeetechMotorsBus", + "lerobot.robots.so_follower.so_follower_base.FeetechMotorsBus", side_effect=_bus_side_effect, ), patch.object(SO100Follower, "configure", lambda self: None), From 242b65d2df4e7cecf2a87701c5a0ee2736d3ac63 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 8 Jan 2026 14:45:07 +0100 Subject: [PATCH 025/111] chore(docs): 
update code block syntax to specify python for clarity (#2770) --- docs/source/phone_teleop.mdx | 16 ++++++++-------- docs/source/processors_robots_teleop.mdx | 6 +++--- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index 76e3c367c..06e524975 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -44,7 +44,7 @@ Modify the examples to use `PhoneOS.IOS` or `PhoneOS.ANDROID` in `PhoneConfig`. Teleoperation example: -```36:43:examples/phone_so100_teleop.py +```python from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID @@ -103,7 +103,7 @@ Additionally you can customize mapping or safety limits by editing the processor - Kinematics are used in multiple steps. We use [Placo](https://github.com/Rhoban/placo) which is a wrapper around Pinocchio for handling our kinematics. We construct the kinematics object by passing the robot's URDF and target frame. We set `target_frame_name` to the gripper frame. - ```examples/phone_to_so100/teleoperate.py + ```python kinematics_solver = RobotKinematics( urdf_path="./SO101/so101_new_calib.urdf", target_frame_name="gripper_frame_link", @@ -114,7 +114,7 @@ Additionally you can customize mapping or safety limits by editing the processor - The `MapPhoneActionToRobotAction` step converts the calibrated phone pose and inputs into target deltas and gripper commands, below is shown what the step outputs. 
- ```src/lerobot/teleoperators/phone/phone_processor.py + ```python action["enabled"] = enabled action["target_x"] = -pos[1] if enabled else 0.0 action["target_y"] = pos[0] if enabled else 0.0 @@ -127,7 +127,7 @@ Additionally you can customize mapping or safety limits by editing the processor - The `EEReferenceAndDelta` step converts target deltas to an absolute desired EE pose, storing a reference on enable, the `end_effector_step_sizes` are the step sizes for the EE pose and can be modified to change the motion speed. - ```examples/phone_to_so100/teleoperate.py + ```python EEReferenceAndDelta( kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, @@ -138,7 +138,7 @@ Additionally you can customize mapping or safety limits by editing the processor - The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits. - ```examples/phone_to_so100/teleoperate.py + ```python EEBoundsAndSafety( end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.10, @@ -147,7 +147,7 @@ Additionally you can customize mapping or safety limits by editing the processor - The `GripperVelocityToJoint` step turns a velocity‑like gripper input into absolute gripper position using the current measured state. The `speed_factor` is the factor by which the velocity is multiplied. - ```examples/phone_to_so100/teleoperate.py + ```python GripperVelocityToJoint(speed_factor=20.0) ``` @@ -157,7 +157,7 @@ We use different IK initial guesses in the kinematic steps. As initial guess eit - Closed loop (used in record/eval): sets `initial_guess_current_joints=True` so IK starts from the measured joints each frame. 
- ```examples/phone_to_so100/record.py + ```python InverseKinematicsEEToJoints( kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()), @@ -167,7 +167,7 @@ We use different IK initial guesses in the kinematic steps. As initial guess eit - Open loop (used in replay): sets `initial_guess_current_joints=False` so IK continues from the previous IK solution rather than the measured state. This preserves action stability when we replay without feedback. - ```examples/phone_to_so100/replay.py + ```python InverseKinematicsEEToJoints( kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()), diff --git a/docs/source/processors_robots_teleop.mdx b/docs/source/processors_robots_teleop.mdx index 3d8dcb409..093a8e0e3 100644 --- a/docs/source/processors_robots_teleop.mdx +++ b/docs/source/processors_robots_teleop.mdx @@ -30,7 +30,7 @@ Each of these pipelines handle different conversions between different action an Below is an example of the three pipelines that we use in the phone to SO-100 follower examples: -```69:90:examples/phone_so100_record.py +```python phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotAction]( # teleop -> dataset action steps=[ MapPhoneActionToRobotAction(platform=teleop_config.phone_os), @@ -84,7 +84,7 @@ Dataset features are determined by the keys saved in the dataset. 
Each step can Below is and example of how we declare features with the `transform_features` method in the phone to SO-100 follower examples: -```src/lerobot/robots/so100_follower/robot_kinematic_processor.py +```python def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: @@ -103,7 +103,7 @@ Here we declare what PolicyFeatures we modify in this step, so we know what feat Below is an example of how we aggregate and merge features in the phone to SO-100 record example: -```121:145:examples/phone_so100_record.py +```python features=combine_feature_dicts( # Run the feature contract of the pipelines # This tells you how the features would look like after the pipeline steps From 8b6fc0ae05109963e387674187d7dee4fa7d6833 Mon Sep 17 00:00:00 2001 From: Leo Tronchon Date: Thu, 8 Jan 2026 18:06:39 +0100 Subject: [PATCH 026/111] feat(datasets): expose video codec option for dataset recording (#2771) * expose codec options + add tests * pre-commit run -a --- src/lerobot/datasets/lerobot_dataset.py | 21 ++++- src/lerobot/scripts/lerobot_record.py | 7 ++ tests/datasets/test_datasets.py | 100 ++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index d9d4b22d0..d9cc28b30 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -78,6 +78,7 @@ from lerobot.datasets.video_utils import ( from lerobot.utils.constants import HF_LEROBOT_HOME CODEBASE_VERSION = "v3.0" +VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"} class LeRobotDatasetMetadata: @@ -540,11 +541,13 @@ class LeRobotDatasetMetadata: return obj -def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path: +def _encode_video_worker( + video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1" +) -> Path: temp_path 
= Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) img_dir = (root / fpath).parent - encode_video_frames(img_dir, temp_path, fps, overwrite=True) + encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True) shutil.rmtree(img_dir) return temp_path @@ -563,6 +566,7 @@ class LeRobotDataset(torch.utils.data.Dataset): download_videos: bool = True, video_backend: str | None = None, batch_encoding_size: int = 1, + vcodec: str = "libsvtav1", ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -675,8 +679,13 @@ class LeRobotDataset(torch.utils.data.Dataset): You can also use the 'pyav' decoder used by Torchvision, which used to be the default option, or 'video_reader' which is another decoder of Torchvision. batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. + vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc', + 'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1 + encoding is CPU-heavy. """ super().__init__() + if vcodec not in VALID_VIDEO_CODECS: + raise ValueError(f"Invalid vcodec '{vcodec}'. 
Must be one of: {sorted(VALID_VIDEO_CODECS)}") self.repo_id = repo_id self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id self.image_transforms = image_transforms @@ -688,6 +697,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.delta_indices = None self.batch_encoding_size = batch_encoding_size self.episodes_since_last_encoding = 0 + self.vcodec = vcodec # Unused attributes self.image_writer = None @@ -1211,6 +1221,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index, self.root, self.fps, + self.vcodec, ): video_key for video_key in self.meta.video_keys } @@ -1526,7 +1537,7 @@ class LeRobotDataset(torch.utils.data.Dataset): Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. """ - return _encode_video_worker(video_key, episode_index, self.root, self.fps) + return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec) @classmethod def create( @@ -1542,8 +1553,11 @@ class LeRobotDataset(torch.utils.data.Dataset): image_writer_threads: int = 0, video_backend: str | None = None, batch_encoding_size: int = 1, + vcodec: str = "libsvtav1", ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" + if vcodec not in VALID_VIDEO_CODECS: + raise ValueError(f"Invalid vcodec '{vcodec}'. 
Must be one of: {sorted(VALID_VIDEO_CODECS)}") obj = cls.__new__(cls) obj.meta = LeRobotDatasetMetadata.create( repo_id=repo_id, @@ -1560,6 +1574,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.image_writer = None obj.batch_encoding_size = batch_encoding_size obj.episodes_since_last_encoding = 0 + obj.vcodec = vcodec if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 5d2945e67..a81a5d54e 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -27,6 +27,8 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.single_task="Grab the cube" \ --display_data=true + # <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \ + # --dataset.vcodec=h264 \ # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ # --teleop.type=so100_leader \ # --teleop.port=/dev/tty.usbmodem58760431551 \ @@ -165,6 +167,9 @@ class DatasetRecordConfig: # Number of episodes to record before batch encoding videos # Set to 1 for immediate encoding (default behavior), or higher for batched encoding video_encoding_batch_size: int = 1 + # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'. + # Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy. 
def _write_dummy_frame(root, video_key, episode_index, frame_index=0):
    """Create the image directory for one episode and drop a single dummy frame in it.

    The directory is derived from ``DEFAULT_IMAGE_PATH`` so it matches the layout
    the dataset code formats for ``video_key`` / ``episode_index``.

    Args:
        root: Dataset root directory (a ``pathlib.Path``, typically ``tmp_path``).
        video_key: Feature key of the camera stream.
        episode_index: Episode the frame belongs to.
        frame_index: Frame index used when formatting the path template.

    Returns:
        The directory the dummy frame was written to.
    """
    from lerobot.datasets.utils import DEFAULT_IMAGE_PATH

    fpath = DEFAULT_IMAGE_PATH.format(
        image_key=video_key, episode_index=episode_index, frame_index=frame_index
    )
    img_dir = root / Path(fpath).parent
    img_dir.mkdir(parents=True, exist_ok=True)

    # A tiny solid-colour image is enough: the encoder itself is mocked out.
    Image.new("RGB", (64, 64), color="red").save(img_dir / "frame-000000.png")
    return img_dir


def _run_encode_worker_and_capture_kwargs(tmp_path, **worker_kwargs):
    """Run ``_encode_video_worker`` with ``encode_video_frames`` mocked.

    Args:
        tmp_path: Temporary dataset root.
        **worker_kwargs: Extra keyword arguments forwarded to ``_encode_video_worker``
            (e.g. ``vcodec="h264"``); pass nothing to exercise the defaults.

    Returns:
        The keyword arguments the worker passed to ``encode_video_frames``.
    """
    from unittest.mock import patch

    video_key = "observation.images.laptop"
    episode_index = 0
    _write_dummy_frame(tmp_path, video_key, episode_index)

    captured_kwargs = {}

    def mock_encode_video_frames(imgs_dir, video_path, fps, **kwargs):
        captured_kwargs.update(kwargs)
        # Create a dummy output file so the worker doesn't fail
        Path(video_path).parent.mkdir(parents=True, exist_ok=True)
        Path(video_path).touch()

    with patch("lerobot.datasets.lerobot_dataset.encode_video_frames", side_effect=mock_encode_video_frames):
        _encode_video_worker(video_key, episode_index, tmp_path, fps=30, **worker_kwargs)

    return captured_kwargs


def test_encode_video_worker_forwards_vcodec(tmp_path):
    """Test that _encode_video_worker correctly forwards the vcodec parameter to encode_video_frames."""
    captured_kwargs = _run_encode_worker_and_capture_kwargs(tmp_path, vcodec="h264")

    assert "vcodec" in captured_kwargs
    assert captured_kwargs["vcodec"] == "h264"


def test_encode_video_worker_default_vcodec(tmp_path):
    """Test that _encode_video_worker uses libsvtav1 as the default codec."""
    # No vcodec is passed, so whatever reaches the encoder is the worker's default.
    captured_kwargs = _run_encode_worker_and_capture_kwargs(tmp_path)

    assert "vcodec" in captured_kwargs
    assert captured_kwargs["vcodec"] == "libsvtav1"


def test_lerobot_dataset_vcodec_validation():
    """Test that LeRobotDataset.create rejects an unknown vcodec with a ValueError."""
    # Only the call expected to raise lives inside the context manager, so the
    # test cannot pass because of an unrelated statement raising ValueError.
    with pytest.raises(ValueError, match="Invalid vcodec"):
        LeRobotDataset.create(
            repo_id="test/invalid_codec",
            fps=30,
            features={"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}},
            vcodec="invalid_codec",
        )


def test_valid_video_codecs_constant():
    """Test that VALID_VIDEO_CODECS contains the expected codecs."""
    assert "h264" in VALID_VIDEO_CODECS
    assert "hevc" in VALID_VIDEO_CODECS
    assert "libsvtav1" in VALID_VIDEO_CODECS
    assert len(VALID_VIDEO_CODECS) == 3
"true", + reason="This test requires peft and is very slow, not meant for CI", +) + def run_command(cmd, module, args): module = importlib.import_module(f"lerobot.scripts.{module}") From 1d86c9b7f2fc0bbf96db004299cb46161275327f Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 9 Jan 2026 23:08:37 +0100 Subject: [PATCH 028/111] feat(policies): add autoregressive VLAs with tokenization PiFast (#2734) --- docs/source/_toctree.yml | 2 + docs/source/pi0fast.mdx | 182 +++ pyproject.toml | 2 +- src/lerobot/policies/__init__.py | 2 + src/lerobot/policies/factory.py | 4 + src/lerobot/policies/pi0_fast/__init__.py | 21 + .../pi0_fast/configuration_pi0_fast.py | 161 ++ .../policies/pi0_fast/modeling_pi0_fast.py | 1353 +++++++++++++++++ .../policies/pi0_fast/processor_pi0_fast.py | 177 +++ .../policies/pi0_fast/train_fast_tokenizer.py | 539 +++++++ src/lerobot/processor/__init__.py | 3 +- src/lerobot/processor/tokenizer_processor.py | 266 +++- src/lerobot/utils/constants.py | 2 + src/lerobot/utils/import_utils.py | 1 + .../test_pi0_fast_original_vs_lerobot.py | 504 ++++++ 15 files changed, 3214 insertions(+), 5 deletions(-) create mode 100644 docs/source/pi0fast.mdx create mode 100644 src/lerobot/policies/pi0_fast/__init__.py create mode 100644 src/lerobot/policies/pi0_fast/configuration_pi0_fast.py create mode 100644 src/lerobot/policies/pi0_fast/modeling_pi0_fast.py create mode 100644 src/lerobot/policies/pi0_fast/processor_pi0_fast.py create mode 100644 src/lerobot/policies/pi0_fast/train_fast_tokenizer.py create mode 100644 tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 381e95dc4..2b8086cd7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -37,6 +37,8 @@ title: SmolVLA - local: pi0 title: π₀ (Pi0) + - local: pi0fast + title: π₀-FAST (Pi0Fast) - local: pi05 title: π₀.₅ (Pi05) - local: groot diff --git a/docs/source/pi0fast.mdx b/docs/source/pi0fast.mdx new file 
mode 100644 index 000000000..e64355765 --- /dev/null +++ b/docs/source/pi0fast.mdx @@ -0,0 +1,182 @@ +# π₀-FAST (Pi0-FAST) + +π₀-FAST is a **Vision-Language-Action model for general robot control** that uses autoregressive next-token prediction to model continuous robot actions. + +## Model Overview + +π₀-FAST combines the power of Vision-Language Models with a novel action tokenization approach called **FAST (Frequency-space Action Sequence Tokenization)**. This enables training autoregressive VLAs on highly dexterous tasks that are impossible with standard binning-based discretization, while training **up to 5x faster** than diffusion-based approaches like π₀. + +### Why FAST? + +Standard approaches for robot action tokenization use simple per-dimension, per-timestep binning schemes. While passable for simple behaviors, this rapidly breaks down for complex and dexterous skills that require precision and high-frequency control. + +FAST solves this by compressing action sequences using signal processing techniques, resulting in a dense sequence of action tokens that can be predicted autoregressively—just like language tokens. + +### How FAST Tokenization Works + +The FAST tokenizer compresses action sequences through the following steps: + +1. **Normalize**: Take a continuous action chunk of shape `(H, D)` where `H` is the horizon and `D` is the action dimension. Normalize using one of the supported normalization methods (Quantiles recommended to handle outliers). + +2. **Discrete Cosine Transform (DCT)**: Apply DCT (via scipy) to each action dimension separately. DCT is a compression algorithm commonly used in image and audio codecs (JPEG, MP3). + +3. **Quantization**: Round and remove insignificant coefficients for each action dimension, producing a sparse frequency matrix. + +4. **Flatten**: Flatten the matrix into a 1D vector, with low-frequency components first. + +5. 
**Byte Pair Encoding (BPE)**: Train a BPE tokenizer to compress the DCT coefficients into dense action tokens, typically achieving **10x compression** over prior tokenization approaches. + +This approach can transform **any existing VLM** into a VLA by training it to predict these FAST tokens. + +## Installation Requirements + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install π₀-FAST dependencies by running: + + ```bash + pip install -e ".[pi]" + ``` + + > [!NOTE] + > For lerobot 0.4.0, if you want to install the pi tag, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`. + > + > This will be solved in the next patch release + +## Training a Custom FAST Tokenizer + +You have two options for the FAST tokenizer: + +1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer. + +2. **Train your own tokenizer**: For maximum performance on your specific dataset, you can finetune the tokenizer on your own data. 
+ +### Training Your Own Tokenizer + +```bash +python src/lerobot/policies/pi0_fast/train_fast_tokenizer.py \ + --repo_id "user/my-lerobot-dataset" \ + --action_horizon 10 \ + --encoded_dims "0:6" \ + --vocab_size 1024 \ + --scale 10.0 \ + --normalization_mode QUANTILES \ + --output_dir "./my_fast_tokenizer" \ + --push_to_hub \ + --hub_repo_id "username/my-action-tokenizer" +``` + +### Key Tokenizer Parameters + +| Parameter | Description | Default | +| ---------------------- | --------------------------------------------------------------------------------- | ------------ | +| `--repo_id` | LeRobot dataset repository ID | Required | +| `--action_horizon` | Number of future actions in each chunk | `10` | +| `--encoded_dims` | Comma-separated dimension ranges to encode (e.g., `"0:6,7:23"`) | `"0:6,7:23"` | +| `--vocab_size` | BPE vocabulary size | `1024` | +| `--scale` | DCT scaling factor for quantization | `10.0` | +| `--normalization_mode` | Normalization mode (`MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`, `IDENTITY`) | `QUANTILES` | +| `--sample_fraction` | Fraction of chunks to sample per episode | `0.1` | + +## Usage + +To use π₀-FAST in LeRobot, specify the policy type as: + +```python +policy.type=pi0_fast +``` + +## Training + +For training π₀-FAST, you can use the LeRobot training script: + +```bash +python src/lerobot/scripts/lerobot_train.py \ + --dataset.repo_id=your_dataset \ + --policy.type=pi0_fast \ + --output_dir=./outputs/pi0fast_training \ + --job_name=pi0fast_training \ + --policy.pretrained_path=lerobot/pi0_fast_base \ + --policy.dtype=bfloat16 \ + --policy.gradient_checkpointing=true \ + --policy.chunk_size=10 \ + --policy.n_action_steps=10 \ + --policy.max_action_tokens=256 \ + --steps=100000 \ + --batch_size=4 \ + --policy.device=cuda +``` + +### Key Training Parameters + +| Parameter | Description | Default | +| -------------------------------------- | -------------------------------------------------- | ---------------------------- | 
+| `--policy.gradient_checkpointing=true` | Reduces memory usage significantly during training | `false` | +| `--policy.dtype=bfloat16` | Use mixed precision training for efficiency | `float32` | +| `--policy.chunk_size` | Number of action steps to predict (action horizon) | `50` | +| `--policy.n_action_steps` | Number of action steps to execute | `50` | +| `--policy.max_action_tokens` | Maximum number of FAST tokens per action chunk | `256` | +| `--policy.action_tokenizer_name` | FAST tokenizer to use | `physical-intelligence/fast` | +| `--policy.compile_model=true` | Enable torch.compile for faster training | `false` | + +## Inference + +### KV-Caching for Fast Inference + +π₀-FAST supports **KV-caching**, a widely used optimization in LLM inference. This caches the key-value pairs from the attention mechanism, avoiding redundant computation during autoregressive decoding. + +```python +# KV-caching is enabled by default +policy.use_kv_cache=true +``` + +### Inference Example + +```python +from lerobot.policies.pi0_fast import PI0FastPolicy, PI0FastConfig + +# Load the policy +policy = PI0FastPolicy.from_pretrained("your-model-path") + +# During inference +actions = policy.predict_action_chunk(batch) +``` + +## Model Architecture + +π₀-FAST uses a PaliGemma-based architecture: + +- **Vision Encoder**: SigLIP vision tower for image understanding +- **Language Model**: Gemma 2B for processing language instructions and predicting action tokens + +The model takes images, text instructions, and robot state as input, and outputs discrete FAST tokens that are decoded back to continuous actions. 
+ +## Configuration Options + +| Parameter | Description | Default | +| -------------------- | ----------------------------------------------- | ---------- | +| `paligemma_variant` | VLM backbone variant (`gemma_300m`, `gemma_2b`) | `gemma_2b` | +| `max_state_dim` | Maximum state vector dimension (padded) | `32` | +| `max_action_dim` | Maximum action vector dimension (padded) | `32` | +| `temperature` | Sampling temperature (0.0 for greedy) | `0.0` | +| `max_decoding_steps` | Maximum decoding steps | `256` | +| `use_kv_cache` | Enable KV caching for faster inference | `true` | + +## Comparison with π₀ + +| Feature | π₀ | π₀-FAST | +| --------------------- | ------------------------- | ---------------------------- | +| Action Representation | Flow Matching (Diffusion) | Autoregressive Tokens (FAST) | +| Training Speed | 1x | **5x faster** | +| Dexterity | High | High | +| Inference Method | Iterative Denoising | Autoregressive Decoding | +| KV-Caching | N/A | Supported | + +## License + +This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). 
+ +## References + +- [FAST: Efficient Robot Action Tokenization](https://www.physicalintelligence.company/research/fast) - Physical Intelligence Blog +- [OpenPI Repository](https://github.com/Physical-Intelligence/openpi) - Original implementation +- [FAST Tokenizer on Hugging Face](https://huggingface.co/physical-intelligence/fast) - Pre-trained tokenizer diff --git a/pyproject.toml b/pyproject.toml index f25645319..75738d2de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,7 +127,7 @@ wallx = [ "torchdiffeq==0.2.5", "qwen_vl_utils==0.0.11" ] -pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] +pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] groot = [ "lerobot[transformers-dep]", diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 99275e787..c7951f028 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -16,6 +16,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .groot.configuration_groot import GrootConfig as GrootConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.processor_smolvla import SmolVLANewLineProcessor @@ -29,6 +30,7 @@ __all__ = [ "DiffusionConfig", "PI0Config", "PI05Config", + "PI0FastConfig", "SmolVLAConfig", "SARMConfig", "TDMPCConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 8c414f235..fff08ad37 100644 --- a/src/lerobot/policies/factory.py +++ 
b/src/lerobot/policies/factory.py @@ -91,6 +91,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.pi0.modeling_pi0 import PI0Policy return PI0Policy + elif name == "pi0_fast": + from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy + + return PI0FastPolicy elif name == "pi05": from lerobot.policies.pi05.modeling_pi05 import PI05Policy diff --git a/src/lerobot/policies/pi0_fast/__init__.py b/src/lerobot/policies/pi0_fast/__init__.py new file mode 100644 index 000000000..a0277da0f --- /dev/null +++ b/src/lerobot/policies/pi0_fast/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_pi0_fast import PI0FastConfig +from .modeling_pi0_fast import PI0FastPolicy +from .processor_pi0_fast import make_pi0_fast_pre_post_processors + +__all__ = ["PI0FastConfig", "PI0FastPolicy", "make_pi0_fast_pre_post_processors"] diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py new file mode 100644 index 000000000..42aa4a132 --- /dev/null +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. 
from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig

DEFAULT_IMAGE_SIZE = 224


@PreTrainedConfig.register_subclass("pi0_fast")
@dataclass
class PI0FastConfig(PreTrainedConfig):
    """Configuration of the ``pi0_fast`` policy.

    Holds backbone selection, padding sizes, tokenizer names, decoding options
    and the optimizer / scheduler presets. ``__post_init__`` rejects inconsistent
    values; ``validate_features`` fills in default input/output features.
    """

    # --- Backbone -----------------------------------------------------------
    paligemma_variant: str = "gemma_2b"
    action_expert_variant: str = "gemma_300m"
    dtype: str = "float32"  # Options: "bfloat16", "float32"

    # --- Action chunking ----------------------------------------------------
    chunk_size: int = 50  # Number of action steps to predict, in openpi called "action_horizon"
    n_action_steps: int = 50  # Number of action steps to execute

    # State/action vectors shorter than these are zero-padded up to them
    max_state_dim: int = 32
    max_action_dim: int = 32
    max_action_tokens: int = 256

    # Real-Time Chunking (RTC) configuration
    rtc_config: RTCConfig | None = None

    # (height, width) images are resized to; see openpi `preprocessing_pytorch.py`
    image_resolution: tuple[int, int] = (
        DEFAULT_IMAGE_SIZE,
        DEFAULT_IMAGE_SIZE,
    )

    # Number of placeholder camera features registered by `validate_features`
    empty_cameras: int = 0

    # --- Tokenization / decoding ---------------------------------------------
    tokenizer_max_length: int = 200  # see openpi `__post_init__`
    text_tokenizer_name: str = "google/paligemma-3b-pt-224"
    action_tokenizer_name: str = "physical-intelligence/fast"
    temperature: float = 0.0
    max_decoding_steps: int = 256
    fast_skip_tokens: int = 128

    # Whether to validate that decoded action tokens start with "Action: " prefix
    validate_action_token_prefix: bool = True

    # Whether to use KV cache for faster autoregressive decoding
    use_kv_cache: bool = True

    # NOTE(review): the original inline comments said "Pi0Fast uses quantiles" for
    # STATE and ACTION, but the defaults configured here are MEAN_STD - confirm
    # which one is intended.
    normalization_mapping: dict[str, NormalizationMode] = field(
        default_factory=lambda: {
            "VISUAL": NormalizationMode.IDENTITY,
            "STATE": NormalizationMode.MEAN_STD,
            "ACTION": NormalizationMode.MEAN_STD,
        }
    )

    # --- Training -------------------------------------------------------------
    gradient_checkpointing: bool = False  # Enable gradient checkpointing for memory optimization
    compile_model: bool = False  # Whether to use torch.compile for model optimization
    compile_mode: str = "max-autotune"  # Torch compile mode
    device: str | None = None  # Device to use for the model (None = auto-detect)

    # Optimizer settings: see openpi `AdamW`
    optimizer_lr: float = 2.5e-5  # see openpi `CosineDecaySchedule: peak_lr`
    optimizer_betas: tuple[float, float] = (0.9, 0.95)
    optimizer_eps: float = 1e-8
    optimizer_weight_decay: float = 0.01
    optimizer_grad_clip_norm: float = 1.0

    # Scheduler settings: see openpi `CosineDecaySchedule`
    # Note: These will auto-scale if --steps < scheduler_decay_steps
    # For example, --steps=3000 will scale warmup to 100 and decay to 3000
    scheduler_warmup_steps: int = 1_000
    scheduler_decay_steps: int = 30_000
    scheduler_decay_lr: float = 2.5e-6

    def __post_init__(self):
        """Run the parent hook, then reject inconsistent settings.

        Raises:
            ValueError: If ``n_action_steps`` exceeds ``chunk_size``, or if
                ``paligemma_variant`` / ``dtype`` is not one of the accepted values.
        """
        super().__post_init__()

        if self.chunk_size < self.n_action_steps:
            raise ValueError(
                f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})"
            )

        if self.paligemma_variant not in ("gemma_300m", "gemma_2b"):
            raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}")

        if self.dtype not in ("bfloat16", "float32"):
            raise ValueError(f"Invalid dtype: {self.dtype}")

    def validate_features(self) -> None:
        """Register placeholder cameras and default state/action features.

        Adds ``empty_cameras`` visual features shaped ``(3, *image_resolution)``,
        and creates ``observation.state`` / ``action`` features sized to the
        padded maxima when they are not already declared.
        """
        for i in range(self.empty_cameras):
            self.input_features[f"observation.images.empty_camera_{i}"] = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, *self.image_resolution),
            )

        if "observation.state" not in self.input_features:
            self.input_features["observation.state"] = PolicyFeature(
                type=FeatureType.STATE,
                shape=(self.max_state_dim,),
            )

        if "action" not in self.output_features:
            self.output_features["action"] = PolicyFeature(
                type=FeatureType.ACTION,
                shape=(self.max_action_dim,),
            )

    def get_optimizer_preset(self) -> AdamWConfig:
        """Return the AdamW configuration built from the ``optimizer_*`` fields."""
        preset = AdamWConfig(
            lr=self.optimizer_lr,
            betas=self.optimizer_betas,
            eps=self.optimizer_eps,
            weight_decay=self.optimizer_weight_decay,
            grad_clip_norm=self.optimizer_grad_clip_norm,
        )
        return preset

    def get_scheduler_preset(self):
        """Return the cosine-decay-with-warmup scheduler configuration."""
        preset = CosineDecayWithWarmupSchedulerConfig(
            peak_lr=self.optimizer_lr,
            decay_lr=self.scheduler_decay_lr,
            num_warmup_steps=self.scheduler_warmup_steps,
            num_decay_steps=self.scheduler_decay_steps,
        )
        return preset

    @property
    def observation_delta_indices(self) -> None:
        """No observation delta indices are requested."""
        return None

    @property
    def action_delta_indices(self) -> list:
        """Indices ``0 .. chunk_size - 1``, one per predicted action step."""
        return list(range(self.chunk_size))

    @property
    def reward_delta_indices(self) -> None:
        """No reward delta indices are requested."""
        return None
b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -0,0 +1,1353 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import logging +import math +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal, TypedDict + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from typing_extensions import Unpack + +from lerobot.utils.import_utils import _scipy_available, _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _scipy_available: + from scipy.fftpack import idct +else: + idct = None + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoTokenizer + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + PaliGemmaForConditionalGeneration = None + AutoTokenizer = None + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.utils.constants import ( + ACTION, + ACTION_TOKEN_MASK, + ACTION_TOKENS, + OBS_LANGUAGE_ATTENTION_MASK, + 
OBS_LANGUAGE_TOKENS, + OPENPI_ATTENTION_MASK_VALUE, +) + + +class ActionSelectKwargs(TypedDict, total=False): + temperature: float | None + + +def pad_vector(vector, new_dim): + """Pad the last dimension of a vector to new_dim with zeros. + + Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] >= new_dim: + return vector + return F.pad(vector, (0, new_dim - vector.shape[-1])) + + +def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) + + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = 
torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + return padded_images + + +class GemmaConfig: # see openpi `gemma.py: Config` + """Configuration for Gemma model variants.""" + + def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim): + self.width = width + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + +def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config` + """Returns config for specified gemma variant.""" + if variant == "gemma_300m": + return GemmaConfig( + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + elif variant == "gemma_2b": + return GemmaConfig( + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + else: + raise ValueError(f"Unknown variant: {variant}") + + +class PI0FastPaliGemma(nn.Module): + """PaliGemma model for PI0Fast""" + + def __init__( + self, + vlm_config, + use_adarms=None, + precision: Literal["bfloat16", "float32"] = "bfloat16", + ): + if use_adarms is None: + use_adarms = [False, False] + super().__init__() + + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + 
vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.torch_dtype = "float32" + + self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + + self.to_bfloat16_for_selected_params(precision) + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def embed_image(self, image: torch.Tensor): + return 
self.paligemma.model.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + # prefix_output to be used for the language head + # shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048 + prefix_output = prefix_output.last_hidden_state + suffix_output = None + return [prefix_output, suffix_output], prefix_past_key_values + + +class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI0Fast PyTorch model.""" + + def __init__( + self, + config: PI0FastConfig, + rtc_processor: RTCProcessor | None = None, + paligemma_tokenizer: "AutoTokenizer | None" = None, + ): + super().__init__() + self.config = config + self.rtc_processor = rtc_processor + self._paligemma_tokenizer = paligemma_tokenizer + + paligemma_config = get_gemma_config(config.paligemma_variant) + + self.paligemma_with_expert = PI0FastPaliGemma( + paligemma_config, + use_adarms=[False, True], + precision=config.dtype, + ) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions_fast = 
torch.compile(self.sample_actions_fast, mode=config.compile_mode) + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + # Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + logging.info("Enabled gradient checkpointing for PI0FastPytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + # Call the proper gradient_checkpointing_disable() method + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable() + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable() + logging.info("Disabled gradient checkpointing for PI0FastPytorch model") + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None): + """Helper method to prepare 4D attention masks 
for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + if dtype is not None: + result = result.to(dtype=dtype) + return result + + def embed_prefix_fast( + self, + images, + img_masks, + tokens, + masks, + fast_action_tokens=None, + fast_action_masks=None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + """Embed images, language tokens, and FAST action tokens. + + Attention pattern: + - Images + Language: bidirectional among themselves + - FAST: attend to images + language, causal among themselves + + Args: + images: List of image tensors + img_masks: List of image masks + tokens: Language instruction tokens + masks: Attention masks for tokens + fast_action_tokens: FAST action tokens (discrete token IDs) + fast_action_masks: Padding masks for FAST action tokens + + Returns: + embs: Concatenated embeddings [images, tokens, fast_action_tokens] + pad_masks: Padding masks + att_masks: 2D attention mask + total_T_images: Total number of image tokens + num_fast_embs: Number of FAST action token embeddings + """ + embs = [] + pad_masks = [] + att_mask_segments = [] + total_t_images = 0 + num_fast_embs = 0 + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_mask_segments.append(("image", num_img_embs)) + total_t_images += num_img_embs + + # Process language instruction tokens + def lang_embed_func(tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, tokens) + embs.append(lang_emb) + 
pad_masks.append(masks) + + num_lang_embs = lang_emb.shape[1] + att_mask_segments.append(("language", num_lang_embs)) + + # Process FAST action tokens (discrete token IDs) + if fast_action_tokens is not None: + + def fast_action_embed_func(fast_action_tokens): + fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens) + fast_emb_dim = fast_emb.shape[-1] + return fast_emb * math.sqrt(fast_emb_dim) + + fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) + embs.append(fast_action_emb) + + num_fast_embs = fast_action_tokens.shape[1] + pad_masks.append(fast_action_masks) + att_mask_segments.append(("fast", num_fast_embs)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + + # Create custom 2D attention mask: + # - Images + Language: bidirectional among themselves + # - FAST: attend to images + language, causal among themselves + att_masks = self._create_custom_attention_mask_fast(att_mask_segments, pad_masks, bsize) + + return embs, pad_masks, att_masks, total_t_images, num_fast_embs + + def _create_custom_attention_mask_fast(self, att_mask_segments, pad_masks, bsize): + """Create custom 2D attention mask. 
+ + Attention rules: + - Images + Language: bidirectional among themselves + - FAST: attend to images + language, causal among themselves + """ + total_len = sum(length for _, length in att_mask_segments) + device = pad_masks.device + + att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device) + + positions = [] + current_pos = 0 + for seg_type, seg_len in att_mask_segments: + positions.append((seg_type, current_pos, current_pos + seg_len)) + current_pos += seg_len + + for _i, (query_type, query_start, query_end) in enumerate(positions): + for _j, (key_type, key_start, key_end) in enumerate(positions): + # Images and Language can attend to each other bidirectionally + if ( + query_type in ["image", "language"] + and key_type in ["image", "language"] + or query_type == "fast" + and key_type in ["image", "language"] + ): + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # FAST tokens attend causally to themselves + elif query_type == "fast" and key_type == "fast": + fast_len = query_end - query_start + causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device)) + att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] + + # Apply padding masks + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + + return att_2d_masks + + def forward( + self, + images, + img_masks, + tokens, + masks, + fast_action_tokens, + fast_action_masks, + ) -> dict: + """Forward pass for PI0Fast. + + This implements the Pi0FAST training objective: predict next action token + using cross-entropy loss. 
+ + Args: + images: List of image tensors + img_masks: List of image masks + tokens: Language instruction tokens + masks: Attention masks for tokens + fast_action_tokens: Discrete action token IDs [B, max_action_tokens] + fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens] + + Returns: + Dictionary with 'fast_loss' and 'loss' keys + """ + if fast_action_tokens is None or fast_action_masks is None: + raise ValueError("fast_action_tokens and fast_action_masks are required for FAST-only mode") + + # Embed prefix with FAST tokens + prefix_embs, prefix_pad_masks, prefix_att_masks, total_t_images, num_fast_embs = ( + self.embed_prefix_fast( + images, + img_masks, + tokens, + masks, + fast_action_tokens=fast_action_tokens, + fast_action_masks=fast_action_masks, + ) + ) + + # Convert embeddings to bfloat16 if needed + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + # for next-token prediction, input tokens [0:T-1] to predict tokens [1:T] + input_embs = prefix_embs + input_pad_masks = prefix_pad_masks + input_att_masks = prefix_att_masks + + position_ids = torch.cumsum(input_pad_masks, dim=1) - 1 + att_2d_4d = self._prepare_attention_masks_4d(input_att_masks, dtype=input_embs.dtype) + + # forward pass through paligemma (language model) + (prefix_out, _), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[input_embs, None], # No suffix/action expert + use_cache=False, + adarms_cond=[None, None], + ) + + # Get logits for FAST action tokens using the FAST LM head + # only compute logits for the positions that predict FAST tokens + lm_head = self.paligemma_with_expert.paligemma.lm_head + + # Targets are the FAST action tokens + fast_targets = fast_action_tokens # (B, num_fast_embs) + + # extract logits for FAST token prediction + 
fast_hidden = prefix_out[:, -fast_targets.shape[1] :, :] + fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size) + + # Shift left for next-step prediction and shift target + # logits[:, i] predicts targets[:, i+1] + fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # shift logits left + fast_targets = fast_targets[:, 1:] # shift targets right + fast_action_masks = fast_action_masks[:, 1:] # shift masks to match targets + + # compute cross-entropy loss + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) + fast_targets_flat = fast_targets.reshape(-1) + + fast_loss_per_token = loss_fct(fast_logits_flat, fast_targets_flat) + fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) + + # apply mask and compute mean loss + masked_fast_loss = fast_loss_per_token * fast_action_masks.float() + fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1) + + return { + "ce_loss": fast_loss, + "loss": fast_loss, + } + + @torch.no_grad() + def sample_actions_fast( + self, + images, + img_masks, + tokens, + masks, + max_decoding_steps=None, + temperature=0.0, + ) -> torch.Tensor: + """ + Inefficient but safe autoregressive decoding for FAST tokens. + Matches the pattern of _generate_subtask_tokens. + TODO: jadechoghari, should we move this logic to PI0FastPolicy class? + """ + if max_decoding_steps is None: + max_decoding_steps = self.config.max_action_tokens + + bsize = tokens.shape[0] + device = tokens.device + lm_head = self.paligemma_with_expert.paligemma.lm_head + + # add bos token after tokens + bos_token = torch.full( + (bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device + ) + tokens = torch.cat([tokens, bos_token], dim=1) + masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1) + + # 1. 
Initial Embedding (matches training prefix) + # prefix_embs will include [Images, Language Prompt, BOS] + prefix_embs, prefix_pad_masks, prefix_att_masks, total_t_images, _ = self.embed_prefix_fast( + images, img_masks, tokens, masks, fast_action_tokens=None, fast_action_masks=None + ) + + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device) + + # 2. Decoding Loop (each step re-computes full sequence) + for t in range(max_decoding_steps): + # always re-calculate position IDs from the current pad mask + position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype) + + # full forward pass (no kv cache) + (prefix_out, _), _ = self.paligemma_with_expert.forward( + attention_mask=att_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=False, + adarms_cond=[None, None], + ) + + # predict next token from the very last sequence position + last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size) + + if temperature > 0: + probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True) + + generated_action_tokens[:, t] = next_token.squeeze(-1) + + # 3. 
Update sequence for next iteration (unless it's the last step) + if t < max_decoding_steps - 1: + # embed the newly generated token + next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) + next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) + if prefix_embs.dtype == torch.bfloat16: + next_token_emb = next_token_emb.to(dtype=torch.bfloat16) + + # append to embeddings + prefix_embs = torch.cat([prefix_embs, next_token_emb], dim=1) + + # update padding mask (new token is always valid/1) + prefix_pad_masks = torch.cat( + [prefix_pad_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1 + ) + + # update 2d attention mask: grow the matrix + old_len = prefix_att_masks.shape[1] + new_len = old_len + 1 + new_att_masks = torch.zeros((bsize, new_len, new_len), dtype=torch.bool, device=device) + new_att_masks[:, :old_len, :old_len] = prefix_att_masks + # new token attends to all non-padding tokens in the updated sequence + new_att_masks[:, -1, :] = prefix_pad_masks + prefix_att_masks = new_att_masks + return generated_action_tokens + + @torch.no_grad() + def sample_actions_fast_kv_cache( + self, + images, + img_masks, + tokens, + masks, + max_decoding_steps=None, + temperature=0.0, + ) -> torch.Tensor: + """ + Optimized autoregressive decoding for FAST tokens using KV Caching. + """ + if max_decoding_steps is None: + max_decoding_steps = self.config.max_action_tokens + + bsize = tokens.shape[0] + device = tokens.device + lm_head = self.paligemma_with_expert.paligemma.lm_head + + # --- 1. PREFILL PHASE --- + # Process Images + Text Prompt + BOS token once to populate the KV cache. 
+ + # Add BOS token to the prompt + bos_token = torch.full( + (bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device + ) + tokens_in = torch.cat([tokens, bos_token], dim=1) + masks_in = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1) + + # Embed prefix [Images, Language, BOS] + # fast_action_tokens=None means we are just embedding the condition (images+text) + prefix_embs, prefix_pad_masks, prefix_att_masks, total_t_images, _ = self.embed_prefix_fast( + images, img_masks, tokens_in, masks_in, fast_action_tokens=None, fast_action_masks=None + ) + + # Ensure correct precision (bfloat16/float32) + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + # Create position IDs (cumsum of mask - 1) + position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Create 4D mask for the prefix + att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype) + + # Forward pass (Prefill) with use_cache=True + # We only pass [prefix_embs, None] because we aren't using the suffix (expert) model yet + (prefix_out, _), past_key_values = self.paligemma_with_expert.forward( + attention_mask=att_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, # Enable caching + adarms_cond=[None, None], + ) + + # Sample the first action token from the last logit of the prefix + last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, V) + if temperature > 0: + probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True) + + # Initialize storage for generated tokens + generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device) + generated_action_tokens[:, 0] = 
next_token.squeeze(-1) + + # Track valid tokens mask (0 for pad, 1 for valid) + # We need this to tell the new token what it can attend to (images + text + past actions) + current_pad_mask = prefix_pad_masks + + # --- 2. DECODING PHASE --- + # Generate remaining tokens one by one using the cache. + + for t in range(1, max_decoding_steps): + # Embed the single previous token + # We use embed_language_tokens directly to avoid overhead of full prefix embedding + next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) + next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) + if prefix_embs.dtype == torch.bfloat16: + next_token_emb = next_token_emb.to(dtype=torch.bfloat16) + + # Update Pad Mask: append 1s for the new valid token + new_column = torch.ones((bsize, 1), dtype=torch.bool, device=device) + current_pad_mask = torch.cat([current_pad_mask, new_column], dim=1) + + # Update Position IDs for the single new token + current_position_ids = (torch.sum(current_pad_mask, dim=1, keepdim=True) - 1).long() + + # Create Attention Mask for the single new step + # The new token attends to all valid tokens in history (captured by current_pad_mask). + # Shape becomes (B, 1, 1, Total_Len) which works with HF's cache logic. 
+ step_att_mask = self._prepare_attention_masks_4d( + current_pad_mask.unsqueeze(1), dtype=next_token_emb.dtype + ) + + # Forward pass (Decoding step) + # input_embeds is just the new token (B, 1, D) + (step_out, _), past_key_values = self.paligemma_with_expert.forward( + attention_mask=step_att_mask, + position_ids=current_position_ids, + past_key_values=past_key_values, # Pass updated cache + inputs_embeds=[next_token_emb, None], + use_cache=True, + adarms_cond=[None, None], + ) + + # Sample next token + last_logits = lm_head(step_out[:, -1:, :]) + if temperature > 0: + probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True) + + generated_action_tokens[:, t] = next_token.squeeze(-1) + + return generated_action_tokens + + +class PI0FastPolicy(PreTrainedPolicy): + """PI0Fast Policy for LeRobot.""" + + config_class = PI0FastConfig + name = "pi0_fast" + + def __init__( + self, + config: PI0FastConfig, + **kwargs, + ): + """ + Args: + config: Policy configuration class instance. 
+ """ + super().__init__(config) + config.validate_features() + self.config = config + + # Load tokenizers first + try: + from transformers import AutoProcessor, AutoTokenizer + + # Load FAST tokenizer + self.action_tokenizer = AutoProcessor.from_pretrained( + config.action_tokenizer_name, trust_remote_code=True + ) + + # Load PaliGemma tokenizer for token conversion + self._paligemma_tokenizer = AutoTokenizer.from_pretrained( + config.text_tokenizer_name, trust_remote_code=True, add_eos_token=True, add_bos_token=False + ) + + logging.info("Loaded FAST tokenizer for action detokenization") + except Exception as e: + logging.error(f"Failed to load FAST tokenizer for action detokenization: {e}") + logging.error("Tokenizer loading is required for proper policy initialization; aborting.") + raise RuntimeError("Failed to load required tokenizers for PI0FastPolicy initialization") from e + + # Initialize the core PI0Fast model + self.init_rtc_processor() + self.model = PI0FastPytorch( + config, rtc_processor=self.rtc_processor, paligemma_tokenizer=self._paligemma_tokenizer + ) + + # Enable gradient checkpointing if requested + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + self.model.to(config.device) + + self.reset() + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping and display important disclaimer.""" + print( + "The PI0Fast model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. 
\n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise create default config + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + model = cls(config, **kwargs) + + # Now manually load and remap the state dict + try: + # Try to load the pytorch_model.bin or model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + # Then add "model." 
prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + if remap_count <= 10: # Only print first 10 to avoid spam + print(f"Remapped: {key} -> {new_key}") + else: + remapped_state_dict[key] = value + + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + if missing_keys: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... 
and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All keys loaded successfully!") + + except Exception as e: + print(f"Warning: Could not remap state dict keys: {e}") + + return model + + def _fix_pytorch_state_dict_keys( + self, state_dict, model_config + ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` + """Fix state dict keys to match current model architecture.""" + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + # Handle vision tower embedding layer potential differences + if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + if ( + key == "model.paligemma_with_expert.paligemma.lm_head.weight" + or key == "paligemma_with_expert.paligemma.lm_head.weight" + ): + fixed_state_dict[ + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ] = value.clone() + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """Reset internal state - called when environment resets.""" + self._action_queue = deque(maxlen=self.config.n_action_steps) + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def init_rtc_processor(self): + """Initialize RTC processor if RTC is enabled in config.""" + self.rtc_processor = None + + # Create processor if config provided + # If RTC is not enabled - we can still track the denoising data + if self.config.rtc_config is not None: + self.rtc_processor = RTCProcessor(self.config.rtc_config) + + model_value = getattr(self, "model", None) + if model_value is not None: + model_value.rtc_processor = self.rtc_processor + + def _rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _preprocess_images(self, batch: 
dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: + """Preprocess images for the model. + + Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. + PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. + """ + images = [] + img_masks = [] + + # Get device from model parameters + device = next(self.parameters()).device + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. " + f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" + ) + + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] + + # Ensure tensor is on the same device as the model + if img.device != device: + img = img.to(device) + + # Ensure float32 dtype for consistency + if img.dtype != torch.float32: + img = img.to(torch.float32) + + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + img = img.permute(0, 2, 3, 1) + + # from openpi preprocess_observation_pytorch: Resize with padding if needed + if img.shape[1:3] != self.config.image_resolution: + img = resize_with_pad_torch(img, *self.config.image_resolution) + + # Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + images.append(img) + # Create mask (all ones for real images) + bsize = img.shape[0] + mask = torch.ones(bsize, dtype=torch.bool, 
device=device) + img_masks.append(mask) + + # Create image features not present in the batch as fully 0 padded images + for _num_empty_cameras in range(len(missing_img_keys)): + img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP + mask = torch.zeros_like(mask) # Mask is zero for empty cameras + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens). + + Args: + tokens: PaliGemma token IDs + + Returns: + Action token IDs + """ + return self._paligemma_tokenizer.vocab_size - 1 - self.config.fast_skip_tokens - tokens + + def decode_actions_with_fast( + self, token_ids: list[int], time_horizon: int, action_dim: int, relaxed_decoding: bool = True + ) -> np.ndarray: + """ + Decodes action token IDs back to continuous action values using the FAST tokenizer. + + Args: + token_ids: List of token IDs to decode. + time_horizon: The number of timesteps for actions. + action_dim: The dimensionality of each action. + relaxed_decoding: Whether to use relaxed decoding (allows partial sequences). + + Returns: + A numpy array representing the decoded actions. 
+ """ + decoded_actions = [] + + for token in token_ids: + try: + decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token) + decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token + + if relaxed_decoding: + # expected sequence length + expected_seq_len = time_horizon * action_dim + diff = expected_seq_len - decoded_dct_coeff.shape[0] + + # apply truncation if too long + if diff < 0: + decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # truncate on the right + + # apply padding if too short + elif diff > 0: + decoded_dct_coeff = np.pad( + decoded_dct_coeff, (0, diff), mode="constant", constant_values=0 + ) + + decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim) + assert decoded_dct_coeff.shape == ( + time_horizon, + action_dim, + ), ( + f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})" + ) + + except Exception as e: + logging.warning(f"Error decoding tokens: {e}") + logging.warning(f"Tokens: {token}") + decoded_dct_coeff = np.zeros((time_horizon, action_dim)) + + decoded_actions.append( + idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho") + ) + + return np.stack(decoded_actions) + + def detokenize_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor: + """ + Detokenizes action tokens back to continuous actions. + + This method converts predicted action tokens from the model back to continuous action values + using the FAST tokenizer. It handles the conversion from PaliGemma token space to action token + space, then decodes the action tokens to continuous values using DCT decoding. + + Args: + tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,) + action_horizon: The number of timesteps for actions. + action_dim: The dimensionality of each action. + + Returns: + The continuous action tensor. 
Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim) + """ + if self.action_tokenizer is None or self._paligemma_tokenizer is None: + raise ValueError( + "Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully." + ) + + # Handle single sample (add batch dimension) + single_sample = tokens.dim() == 1 + if single_sample: + tokens = tokens.unsqueeze(0) + + # Convert token IDs to token strings + decoded_tokens = [self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) for seq in tokens] + # Get the token sequence for "Action: " to remove it + action_prefix_ids = self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False) + action_prefix_tokens = self._paligemma_tokenizer.convert_ids_to_tokens(action_prefix_ids) + action_prefix_len = len(action_prefix_tokens) + + # Clean tokens by removing everything after the first "|" (end-of-action marker) + # and removing all occurrences of "Action: " token sequence + # assert that beginning contain "Action: " + if self.config.validate_action_token_prefix: + for token_seq in decoded_tokens: + assert len(token_seq) >= 2 and token_seq[0] == "Action" and token_seq[1] == ":", ( + f"Token sequence does not start with ['Action', ':']: {token_seq}" + ) + + cleaned_tokens = [] + for token_seq in decoded_tokens: + # Remove everything after "|" + if "|" in token_seq: + token_seq = token_seq[: token_seq.index("|")] + + # Remove all occurrences of "Action: " token sequence + i = 0 + while i <= len(token_seq) - action_prefix_len: + if token_seq[i : i + action_prefix_len] == action_prefix_tokens: + # Found a match, remove it + token_seq = token_seq[:i] + token_seq[i + action_prefix_len :] + else: + i += 1 + + cleaned_tokens.append(token_seq) + + # Convert token strings back to IDs + raw_action_tokens = [ + torch.tensor( + self._paligemma_tokenizer.convert_tokens_to_ids(token_seq), + dtype=torch.long, + device=tokens.device, + ) + for token_seq in cleaned_tokens + 
] + + # Convert PaliGemma tokens to action tokens + action_tokens = [ + self._paligemma_tokens_to_act_tokens(raw_action_token) for raw_action_token in raw_action_tokens + ] + + # Decode action tokens to continuous actions + actions = self.decode_actions_with_fast( + action_tokens, time_horizon=action_horizon, action_dim=action_dim + ) + + # Convert to tensor and return + actions_tensor = torch.tensor(actions, dtype=torch.float32, device=tokens.device) + + # Remove batch dimension if input was single sample + if single_sample: + actions_tensor = actions_tensor.squeeze(0) + + return actions_tensor + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations.""" + assert not self._rtc_enabled(), ( + "RTC is not supported for select_action, use it with predict_action_chunk" + ) + + self.eval() + + # Action queue logic for n_action_steps > 1 + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + # Transpose to get shape (n_action_steps, batch_size, action_dim) + self._action_queue.extend(actions.transpose(0, 1)) + + return self._action_queue.popleft() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + + # FAST-only mode: use autoregressive decoding + tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + + # Get decoding parameters + temperature = self.config.temperature + max_decoding_steps = self.config.max_decoding_steps + + # Sample action tokens autoregressively + if self.config.use_kv_cache: + action_tokens = self.model.sample_actions_fast_kv_cache( + images, + img_masks, + tokens, + masks, + max_decoding_steps=max_decoding_steps, + temperature=temperature, + 
) + else: + action_tokens = self.model.sample_actions_fast( + images, + img_masks, + tokens, + masks, + max_decoding_steps=max_decoding_steps, + temperature=temperature, + ) + + # Detokenize action tokens to continuous actions + action_horizon = self.config.n_action_steps + action_dim = self.config.output_features[ACTION].shape[0] + + continuous_actions = self.detokenize_actions( + action_tokens, action_horizon=action_horizon, action_dim=action_dim + ) + + return continuous_actions + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training.""" + + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + + # Get FAST action tokens from batch + fast_action_tokens = batch.get(ACTION_TOKENS) # (B, max_action_tokens) + fast_action_masks = batch.get(ACTION_TOKEN_MASK) # (B, max_action_tokens) + + # Use full language tokens (no separation into high_level_task and subtask) + tokens = batch.get(OBS_LANGUAGE_TOKENS) + masks = batch.get(OBS_LANGUAGE_ATTENTION_MASK) + + if fast_action_tokens is None or fast_action_masks is None: + raise ValueError( + f"PI0Fast requires {ACTION_TOKENS} and {ACTION_TOKEN_MASK} to be present in the batch" + ) + + loss_dict = self.model.forward( + images, + img_masks, + tokens, + masks, + fast_action_tokens, + fast_action_masks, + ) + + loss = loss_dict["loss"] + detailed_loss_dict = { + "loss": loss.item(), + "ce_loss": loss_dict["ce_loss"].item(), + } + return loss, detailed_loss_dict diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py new file mode 100644 index 000000000..0d9dac673 --- /dev/null +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from dataclasses import dataclass +from typing import Any + +import numpy as np +import torch + +from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig +from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector +from lerobot.processor import ( + ActionTokenizerProcessorStep, + AddBatchDimensionProcessorStep, + DeviceProcessorStep, + NormalizerProcessorStep, + PolicyAction, + PolicyProcessorPipeline, + ProcessorStep, + ProcessorStepRegistry, + RenameObservationsProcessorStep, + TokenizerProcessorStep, + UnnormalizerProcessorStep, +) +from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action +from lerobot.processor.core import EnvTransition, TransitionKey +from lerobot.utils.constants import ( + OBS_STATE, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) + + +@ProcessorStepRegistry.register(name="pi0_fast_prepare_state_tokenizer_processor_step") +@dataclass +class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep): + """ + Processor step to prepare the state and tokenize the language input. 
+ """ + + max_state_dim: int = 32 + task_key: str = "task" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + transition = transition.copy() + + state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE) + if state is None: + raise ValueError("State is required for PI0Fast") + tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key) + if tasks is None: + raise ValueError("No task found in complementary data") + + # TODO: check if this necessary + state = deepcopy(state) + + # Prepare state (pad to max_state_dim) + state = pad_vector(state, self.max_state_dim) + + # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + state_np = state.cpu().numpy() + discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1 + + full_prompts = [] + for i, task in enumerate(tasks): + cleaned_text = task.strip().replace("_", " ").replace("\n", " ") + state_str = " ".join(map(str, discretized_states[i])) + full_prompt = f"Task: {cleaned_text}, State: {state_str};\n" + full_prompts.append(full_prompt) + + transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts + # Normalize state to [-1, 1] range if needed (assuming it's already normalized by normalizer processor step!!) + # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`) + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + This step does not alter the feature definitions. 
+ """ + return features + + +def make_pi0_fast_pre_post_processors( + config: PI0FastConfig, + dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None, +) -> tuple[ + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """ + Constructs pre-processor and post-processor pipelines for the PI0Fast policy. + + The pre-processing pipeline prepares input data for the model by: + 1. Renaming features to match pretrained configurations. + 2. Normalizing input and output features based on dataset statistics. + 3. Adding a batch dimension. + 4. Appending a newline character to the task description for tokenizer compatibility. + 5. Tokenizing the text prompt using the PaliGemma tokenizer. + 6. Moving all data to the specified device. + + The post-processing pipeline handles the model's output by: + 1. Moving data to the CPU. + 2. Unnormalizing the output features to their original scale. + + Args: + config: The configuration object for the PI0Fast policy. + dataset_stats: A dictionary of statistics for normalization. + preprocessor_kwargs: Additional arguments for the pre-processor pipeline. + postprocessor_kwargs: Additional arguments for the post-processor pipeline. + + Returns: + A tuple containing the configured pre-processor and post-processor pipelines. 
+ """ + # Add remaining processors + input_steps: list[ProcessorStep] = [ + RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one + AddBatchDimensionProcessorStep(), + # NOTE: NormalizerProcessorStep MUST come before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep + # because the tokenizer step expects normalized state in [-1, 1] range for discretization + NormalizerProcessorStep( + features={**config.input_features, **config.output_features}, + norm_map=config.normalization_mapping, + stats=dataset_stats, + ), + Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim), + TokenizerProcessorStep( + tokenizer_name=config.text_tokenizer_name, + max_length=config.tokenizer_max_length, + padding_side="right", + padding="max_length", + ), + ActionTokenizerProcessorStep( + action_tokenizer_name=config.action_tokenizer_name, + max_action_tokens=config.max_action_tokens, + fast_skip_tokens=config.fast_skip_tokens, + paligemma_tokenizer_name=config.text_tokenizer_name, + ), + DeviceProcessorStep(device=config.device), + ] + + output_steps: list[ProcessorStep] = [ + UnnormalizerProcessorStep( + features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats + ), + DeviceProcessorStep(device="cpu"), + ] + + return ( + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]]( + steps=input_steps, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, + ), + PolicyProcessorPipeline[PolicyAction, PolicyAction]( + steps=output_steps, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ), + ) diff --git a/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py b/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py new file mode 100644 index 000000000..6a3a1fe69 --- /dev/null +++ b/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py @@ -0,0 +1,539 @@ +"""Train FAST tokenizer for action encoding. + +This script: +1. 
Loads action chunks from LeRobotDataset (with sampling) +2. Applies delta transforms and per-timestamp normalization +3. Trains FAST tokenizer on specified action dimensions +4. Saves tokenizer to assets directory +5. Reports compression statistics +""" + +import json +from pathlib import Path + +import numpy as np +import torch +import tyro +from huggingface_hub import HfApi +from transformers import AutoProcessor + +from lerobot.configs.types import NormalizationMode +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray: + """Apply delta transform to specified dimensions. + + Args: + state: Current state [D] + actions: Future actions [D] + delta_dims: List of dimension indices to apply delta transform to + + Returns: + Transformed actions [D] + """ + if delta_dims is None or len(delta_dims) == 0: + return actions + + delta_actions = actions.copy() + for dim in delta_dims: + delta_actions[dim] = actions[dim] - state[dim] + + return delta_actions + + +def apply_normalization( + data: np.ndarray, + stats: dict[str, np.ndarray], + mode: NormalizationMode, + eps: float = 1e-8, +) -> np.ndarray: + """Apply normalization to data based on the specified mode. 
+ + Args: + data: Data to normalize [N, H, D] or [D] + stats: Dictionary of statistics (mean, std, min, max, q01, q99, q10, q90) + mode: Normalization mode to apply + eps: Small epsilon for numerical stability + + Returns: + Normalized data with the same shape as input + """ + if mode == NormalizationMode.IDENTITY: + return data + + if mode == NormalizationMode.MEAN_STD: + mean = stats.get("mean") + std = stats.get("std") + if mean is None or std is None: + raise ValueError("MEAN_STD mode requires 'mean' and 'std' in stats") + return (data - mean) / np.maximum(std, eps) + + if mode == NormalizationMode.MIN_MAX: + min_val = stats.get("min") + max_val = stats.get("max") + if min_val is None or max_val is None: + raise ValueError("MIN_MAX mode requires 'min' and 'max' in stats") + denom = np.maximum(max_val - min_val, eps) + return 2.0 * (data - min_val) / denom - 1.0 + + if mode == NormalizationMode.QUANTILES: + q01 = stats.get("q01") + q99 = stats.get("q99") + if q01 is None or q99 is None: + raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats") + denom = np.maximum(q99 - q01, eps) + # Clip to quantile range then normalize to [-1, 1] + clipped = np.clip(data, q01, q99) + return 2.0 * (clipped - q01) / denom - 1.0 + + if mode == NormalizationMode.QUANTILE10: + q10 = stats.get("q10") + q90 = stats.get("q90") + if q10 is None or q90 is None: + raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats") + denom = np.maximum(q90 - q10, eps) + # Clip to quantile range then normalize to [-1, 1] + clipped = np.clip(data, q10, q90) + return 2.0 * (clipped - q10) / denom - 1.0 + + raise ValueError(f"Unsupported normalization mode: {mode}") + + +def process_episode(args): + """Process single episode and return action chunks.""" + dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args + + try: + # get episode info + ep_info = dataset.meta.episodes[ep_idx] + from_idx = ep_info["dataset_from_index"] + to_idx = 
ep_info["dataset_to_index"] + ep_length = to_idx - from_idx + + if ep_length < action_horizon: + return None + + # load all frames in episode + # if dataset has episode filtering, we need to use the mapping + states = [] + actions = [] + + for abs_idx in range(from_idx, to_idx): + # map absolute index to relative index if needed + if dataset._absolute_to_relative_idx is not None: + if abs_idx not in dataset._absolute_to_relative_idx: + # this episode's frames aren't in the filtered dataset + return None + rel_idx = dataset._absolute_to_relative_idx[abs_idx] + else: + rel_idx = abs_idx + + frame = dataset.hf_dataset[rel_idx] + + # get state (could be from observation.state or other state key) + if state_key in frame: + state = ( + frame[state_key].numpy() + if torch.is_tensor(frame[state_key]) + else np.array(frame[state_key]) + ) + else: + # if no state key, use zeros (no delta transform) + state = np.zeros_like( + frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]) + ) + + action = ( + frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]) + ) + + states.append(state) + actions.append(action) + + states = np.array(states) + actions = np.array(actions) + + # create action chunks (sliding window) + # all actions in a chunk are relative to the FIRST state in that chunk + action_chunks = [] + + for i in range(len(states) - action_horizon + 1): + current_state = states[i] # First state in chunk + future_absolute_actions = actions[i : i + action_horizon] + + if use_delta_transform: + # relative actions + delta_chunk = np.zeros_like(future_absolute_actions) + for t in range(action_horizon): + delta_chunk[t] = apply_delta_transform( + current_state, + future_absolute_actions[t], + delta_dims, + ) + action_chunks.append(delta_chunk) + else: + # absolute actions (no delta) + action_chunks.append(future_absolute_actions) + + if len(action_chunks) == 0: + return None + + action_chunks = 
np.array(action_chunks) + + # sample chunks + if sample_fraction < 1.0: + n_chunks = len(action_chunks) + n_samples = max(1, int(n_chunks * sample_fraction)) + episode_seed = hash(ep_idx) % (2**31) + rng = np.random.RandomState(episode_seed) + indices = rng.choice(n_chunks, size=n_samples, replace=False) + action_chunks = action_chunks[indices] + + return action_chunks + + except Exception as e: + print(f"Error processing episode {ep_idx}: {e}") + import traceback + + traceback.print_exc() + return None + + +def train_fast_tokenizer( + action_chunks: np.ndarray, + vocab_size: int = 1024, + scale: float = 10.0, +) -> AutoProcessor: + """ + Train FAST tokenizer (BPE on DCT coefficients) on action chunks. + + Uses the .fit() method to train a new tokenizer on the provided data. + + Args: + action_chunks: Array of action chunks [N, H, D] where N=num_chunks, H=horizon, D=action_dim + vocab_size: BPE vocabulary size + scale: DCT scaling factor for quantization + + Returns: + Trained FAST tokenizer + """ + print(f"Training FAST tokenizer on {len(action_chunks)} action chunks...") + print(f"Action chunk shape: {action_chunks.shape}") + print(f"Vocab size: {vocab_size}") + print(f"DCT scale: {scale}") + + # download the tokenizer source code (not pretrained weights) + # we'll train a new tokenizer on our own data + base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) + + # convert action_chunks array to list of arrays (expected by .fit()) + action_data_list = [action_chunks[i] for i in range(len(action_chunks))] + + # train the new tokenizer on our action data using .fit() + # this trains the BPE tokenizer on DCT coefficients + print("Training new tokenizer (this may take a few minutes)...") + tokenizer = base_tokenizer.fit( + action_data_list, + scale=scale, + vocab_size=vocab_size, + time_horizon=action_chunks.shape[1], # action_horizon + action_dim=action_chunks.shape[2], # encoded dimensions + ) + print("✓ Tokenizer 
training complete!") + + # validate it works + sample_chunk = action_chunks[0] + encoded = tokenizer(sample_chunk[None])[0] + if isinstance(encoded, list): + encoded = np.array(encoded) + print(f"Sample encoding: {len(encoded)} tokens for chunk shape {sample_chunk.shape}") + + return tokenizer + + +def compute_compression_stats(tokenizer, action_chunks: np.ndarray): + """Compute compression statistics.""" + print("\nComputing compression statistics...") + + # sample for stats (use max 1000 chunks for speed) + sample_size = min(1000, len(action_chunks)) + sample_indices = np.random.RandomState(42).choice(len(action_chunks), size=sample_size, replace=False) + sample_chunks = action_chunks[sample_indices] + + token_lengths = [] + for chunk in sample_chunks: + encoded = tokenizer(chunk[None])[0] + if isinstance(encoded, list): + token_lengths.append(len(encoded)) + else: + token_lengths.append(encoded.shape[0] if hasattr(encoded, "shape") else len(encoded)) + + token_lengths = np.array(token_lengths) + + # compression ratio: (H * D) / avg_tokens + input_size = action_chunks.shape[1] * action_chunks.shape[2] + avg_tokens = np.mean(token_lengths) + compression_ratio = input_size / avg_tokens + + stats = { + "compression_ratio": float(compression_ratio), + "mean_token_length": float(np.mean(token_lengths)), + "p99_token_length": float(np.percentile(token_lengths, 99)), + "min_token_length": float(np.min(token_lengths)), + "max_token_length": float(np.max(token_lengths)), + } + + print("Compression Statistics:") + print(f" Average compression ratio: {stats['compression_ratio']:.2f}x") + print(f" Mean token length: {stats['mean_token_length']:.1f}") + print(f" P99 token length: {stats['p99_token_length']:.0f}") + print(f" Min token length: {stats['min_token_length']:.0f}") + print(f" Max token length: {stats['max_token_length']:.0f}") + + return stats + + +def main( + repo_id: str, + root: str | None = None, + action_horizon: int = 10, + max_episodes: int | None = None, + 
sample_fraction: float = 0.1, + encoded_dims: str = "0:6,7:23", + delta_dims: str | None = None, + use_delta_transform: bool = False, + state_key: str = "observation.state", + normalization_mode: str = "QUANTILES", + vocab_size: int = 1024, + scale: float = 10.0, + output_dir: str | None = None, + push_to_hub: bool = False, + hub_repo_id: str | None = None, + hub_private: bool = False, +): + """ + Train FAST tokenizer for action encoding. + + Args: + repo_id: LeRobot dataset repository ID + root: Root directory for dataset (default: ~/.cache/huggingface/lerobot) + action_horizon: Number of future actions in each chunk + max_episodes: Max episodes to use (None = all episodes in dataset) + sample_fraction: Fraction of chunks to sample per episode + encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23") + delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5") + use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions) + state_key: Dataset key for state observations (default: "observation.state") + normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) + vocab_size: FAST vocabulary size (BPE vocab size) + scale: DCT scaling factor (default: 10.0) + output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id}) + push_to_hub: Whether to push the tokenizer to Hugging Face Hub + hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name + hub_private: Whether to create a private repository on the Hub + """ + # load dataset + print(f"Loading dataset: {repo_id}") + dataset = LeRobotDataset(repo_id=repo_id, root=root) + print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames") + + # parse normalization mode + try: + norm_mode = NormalizationMode(normalization_mode) + except ValueError as err: + raise ValueError( + f"Invalid normalization_mode: {normalization_mode}. 
" + f"Must be one of: {', '.join([m.value for m in NormalizationMode])}" + ) from err + print(f"Normalization mode: {norm_mode.value}") + + # parse encoded dimensions + encoded_dim_ranges = [] + for range_str in encoded_dims.split(","): + start, end = map(int, range_str.strip().split(":")) + encoded_dim_ranges.append((start, end)) + + total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges) + print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}") + + # parse delta dimensions + delta_dim_list = None + if delta_dims is not None and delta_dims.strip(): + delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")] + print(f"Delta dimensions: {delta_dim_list}") + else: + print("No delta dimensions specified") + + print(f"Use delta transform: {use_delta_transform}") + if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0): + print("Warning: use_delta_transform=True but no delta_dims specified. No delta will be applied.") + + print(f"Action horizon: {action_horizon}") + print(f"State key: {state_key}") + + # determine episodes to process + num_episodes = dataset.num_episodes + if max_episodes is not None: + num_episodes = min(max_episodes, num_episodes) + + print(f"Processing {num_episodes} episodes...") + + # process episodes sequentially (to avoid pickling issues with dataset) + all_chunks = [] + for ep_idx in range(num_episodes): + if ep_idx % 10 == 0: + print(f" Processing episode {ep_idx}/{num_episodes}...") + + chunks = process_episode( + (dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform) + ) + if chunks is not None: + all_chunks.append(chunks) + + # concatenate all chunks + all_chunks = np.concatenate(all_chunks, axis=0) + print(f"Collected {len(all_chunks)} action chunks") + + # extract only encoded dimensions FIRST (before normalization) + encoded_chunks = [] + for start, end in encoded_dim_ranges: + encoded_chunks.append(all_chunks[:, :, start:end]) + 
encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded] + print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions") + + # apply normalization to encoded dimensions + print("\nBefore normalization - overall stats:") + print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}") + print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}") + + # get normalization stats from dataset + norm_stats = dataset.meta.stats + if norm_stats is not None and "action" in norm_stats: + action_stats = norm_stats["action"] + + # build encoded dimension indices + encoded_dim_indices = [] + for start, end in encoded_dim_ranges: + encoded_dim_indices.extend(range(start, end)) + encoded_dim_indices = np.array(encoded_dim_indices) + + # extract stats for encoded dimensions only + encoded_stats = {} + for stat_name, stat_values in action_stats.items(): + if isinstance(stat_values, (list, np.ndarray)): + stat_array = np.array(stat_values) + if len(stat_array) > max(encoded_dim_indices): + encoded_stats[stat_name] = stat_array[encoded_dim_indices] + + if encoded_stats: + print(f"\nNormalization stats for encoded dimensions (mode: {norm_mode.value}):") + for stat_name, stat_values in encoded_stats.items(): + print( + f" {stat_name}: shape={stat_values.shape}, " + f"range=[{np.min(stat_values):.4f}, {np.max(stat_values):.4f}]" + ) + + # apply normalization based on mode + try: + encoded_chunks = apply_normalization(encoded_chunks, encoded_stats, norm_mode, eps=1e-8) + print(f"\nApplied {norm_mode.value} normalization") + except ValueError as e: + print(f"Warning: {e}. 
Using raw actions without normalization.") + + print("\nAfter normalization - overall stats:") + print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}") + print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}") + + print("\nPer-dimension stats (after normalization):") + for d in range(encoded_chunks.shape[-1]): + dim_data = encoded_chunks[:, :, d] + print( + f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, " + f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}" + ) + else: + print("Warning: Could not extract stats for encoded dimensions, using raw actions") + else: + print("Warning: No normalization stats found in dataset, using raw actions") + + print(f"Encoded chunks shape: {encoded_chunks.shape}") + + # train FAST tokenizer + tokenizer = train_fast_tokenizer( + encoded_chunks, + vocab_size=vocab_size, + scale=scale, + ) + + # compute compression statistics + compression_stats = compute_compression_stats(tokenizer, encoded_chunks) + + # save tokenizer + if output_dir is None: + output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}" + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + tokenizer.save_pretrained(output_path) + + # save metadata + metadata = { + "repo_id": repo_id, + "vocab_size": vocab_size, + "scale": scale, + "encoded_dims": encoded_dims, + "encoded_dim_ranges": encoded_dim_ranges, + "total_encoded_dims": total_encoded_dims, + "delta_dims": delta_dims, + "delta_dim_list": delta_dim_list, + "use_delta_transform": use_delta_transform, + "state_key": state_key, + "normalization_mode": norm_mode.value, + "action_horizon": action_horizon, + "num_training_chunks": len(encoded_chunks), + "compression_stats": compression_stats, + } + + with open(output_path / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + + print(f"\nSaved FAST tokenizer to {output_path}") + print(f"Metadata: {json.dumps(metadata, indent=2)}") + + # push 
to Hugging Face Hub if requested + if push_to_hub: + # determine the hub repository ID + if hub_repo_id is None: + hub_repo_id = output_path.name + print(f"\nNo hub_repo_id provided, using: {hub_repo_id}") + + print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}") + print(f" Private: {hub_private}") + + try: + # use the tokenizer's push_to_hub method + tokenizer.push_to_hub( + repo_id=hub_repo_id, + private=hub_private, + commit_message=f"Upload FAST tokenizer trained on {repo_id}", + ) + + # also upload the metadata.json file separately + api = HfApi() + api.upload_file( + path_or_fileobj=str(output_path / "metadata.json"), + path_in_repo="metadata.json", + repo_id=hub_repo_id, + repo_type="model", + commit_message="Upload tokenizer metadata", + ) + + print(f"Successfully pushed tokenizer to: https://huggingface.co/{hub_repo_id}") + except Exception as e: + print(f"Error pushing to hub: {e}") + print(" Make sure you're logged in with `huggingface-cli login`") + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index be11ac1af..676ba29ee 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -75,7 +75,7 @@ from .policy_robot_bridge import ( RobotActionToPolicyActionProcessorStep, ) from .rename_processor import RenameObservationsProcessorStep -from .tokenizer_processor import TokenizerProcessorStep +from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep __all__ = [ "ActionProcessorStep", @@ -122,6 +122,7 @@ __all__ = [ "AddBatchDimensionProcessorStep", "RobotProcessorPipeline", "TokenizerProcessorStep", + "ActionTokenizerProcessorStep", "Torch2NumpyActionProcessorStep", "RobotActionToPolicyActionProcessorStep", "PolicyActionToRobotActionProcessorStep", diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 2ef89c107..93e0395b9 100644 --- 
a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -23,22 +23,29 @@ token IDs and attention masks, which are then added to the observation dictionar from __future__ import annotations +import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.constants import ( + ACTION_TOKEN_MASK, + ACTION_TOKENS, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, +) from lerobot.utils.import_utils import _transformers_available from .core import EnvTransition, TransitionKey -from .pipeline import ObservationProcessorStep, ProcessorStepRegistry +from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: - from transformers import AutoTokenizer + from transformers import AutoProcessor, AutoTokenizer else: + AutoProcessor = None AutoTokenizer = None @@ -268,3 +275,256 @@ class TokenizerProcessorStep(ObservationProcessorStep): ) return features + + +@dataclass +@ProcessorStepRegistry.register(name="action_tokenizer_processor") +class ActionTokenizerProcessorStep(ActionProcessorStep): + """ + Processor step to tokenize action data using a fast action tokenizer. + + This step takes action tensors from an `EnvTransition`, tokenizes them using + a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer), + and returns the tokenized action. + + Requires the `transformers` library to be installed. + + Attributes: + tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast"). + tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored. 
+ trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers). + action_tokenizer: The internal tokenizer/processor instance, loaded during initialization. + paligemma_tokenizer_name: The name of a pretrained PaliGemma tokenizer from the Hugging Face Hub (e.g., "google/paligemma-3b-pt-224"). + """ + + action_tokenizer_name: str | None = None + action_tokenizer_input_object: Any | None = None + trust_remote_code: bool = True + max_action_tokens: int = 256 + fast_skip_tokens: int = 128 + paligemma_tokenizer_name: str = "google/paligemma-3b-pt-224" + # Internal tokenizer instance (not part of the config) + action_tokenizer: Any = field(default=None, init=False, repr=False) + _paligemma_tokenizer: Any = field(default=None, init=False, repr=False) + + def __post_init__(self): + """ + Initializes the action tokenizer after the dataclass is created. + + It checks for the availability of the `transformers` library and loads the tokenizer + either from a provided object or by name from the Hugging Face Hub. + + Raises: + ImportError: If the `transformers` library is not installed. + ValueError: If neither `tokenizer` nor `tokenizer_name` is provided. + """ + if not _transformers_available: + raise ImportError( + "The 'transformers' library is not installed. " + "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep." + ) + + if self.action_tokenizer_input_object is not None: + self.action_tokenizer = self.action_tokenizer_input_object + + elif self.action_tokenizer_name is not None: + if AutoProcessor is None: + raise ImportError("AutoProcessor is not available") + self.action_tokenizer = AutoProcessor.from_pretrained( + self.action_tokenizer_name, trust_remote_code=self.trust_remote_code + ) + else: + raise ValueError( + "Either 'action_tokenizer' or 'action_tokenizer_name' must be provided. " + "Pass a tokenizer object directly or a tokenizer name to auto-load." 
+ ) + + self._paligemma_tokenizer = AutoTokenizer.from_pretrained( + self.paligemma_tokenizer_name, + trust_remote_code=self.trust_remote_code, + add_eos_token=True, + add_bos_token=False, + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Applies action tokenization to the transition. + + This overrides the base class to handle both tokens and mask. + + Args: + transition: The input transition with action data. + + Returns: + The processed transition with tokenized actions and mask in complementary data. + """ + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if action is None: + # During inference, no action is available, skip tokenization + return new_transition + + # Tokenize and get both tokens and mask + tokens, mask = self._tokenize_action(action) + + # Store mask in complementary data + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if complementary_data is None: + complementary_data = {} + complementary_data[ACTION_TOKEN_MASK] = mask + complementary_data[ACTION_TOKENS] = tokens + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition + + def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts action tokens to PaliGemma tokens. + """ + return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens + + def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Tokenizes the action tensor and creates a mask. + + Args: + action: The input action tensor to tokenize. 
Shape: (B, H, action_dim) or (H, action_dim,) + + Returns: + A tuple of (tokens, mask) where: + - tokens: Tensor of token IDs with shape (B, max_action_tokens) + - mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding + """ + if action is None: + raise ValueError("Action cannot be None") + + # Get the device and dtype of the input action + device = action.device if isinstance(action, torch.Tensor) else None + + # Handle single sample (add batch dimension) + single_sample = action.dim() == 1 + if single_sample: + action = action.unsqueeze(0) + + batch_size = action.shape[0] + + # Tokenize the action batch + # The fast tokenizer expects action data and returns token IDs + tokens_list = [] + masks_list = [] + + for i in range(batch_size): + # Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy) + action_cpu = action[i : i + 1].cpu() + tokens = self.action_tokenizer(action_cpu) + + # Convert to numpy array if it's a list + if isinstance(tokens, list) or not isinstance(tokens, torch.Tensor): + tokens = torch.tensor(tokens, dtype=torch.long, device=action.device) + else: + # Move tokens back to the same device as input action + tokens = tokens.to(device=action.device) + + # Flatten to 1D if needed + if tokens.dim() > 1: + tokens = tokens.flatten() + + bos_id = self._paligemma_tokenizer.bos_token_id + # add bos + tokens = torch.cat( + [ + torch.tensor([bos_id], device=action.device), + torch.tensor( + self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), + device=action.device, + ), + self._act_tokens_to_paligemma_tokens(tokens), + torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device), + ] + ) + + # Truncate or pad to max_action_tokens + if len(tokens) > self.max_action_tokens: + logging.warning( + f"Token length ({len(tokens)}) exceeds max length ({self.max_action_tokens}), truncating. 
" + "Consider increasing the `max_action_tokens` in your model config if this happens frequently." + ) + tokens = tokens[: self.max_action_tokens] + mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device) + else: + mask = torch.cat( + [ + torch.ones(len(tokens), dtype=torch.bool, device=action.device), + torch.zeros( + self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device + ), + ] + ) + # Pad tokens with zeros + tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0) + + tokens_list.append(tokens) + masks_list.append(mask) + + # Stack into batched tensors + tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens) + masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens) + + # Remove batch dimension if input was single sample + if single_sample: + tokens_batch = tokens_batch.squeeze(0) + masks_batch = masks_batch.squeeze(0) + + # Move to the same device as the input + if device is not None: + tokens_batch = tokens_batch.to(device) + masks_batch = masks_batch.to(device) + + return tokens_batch, masks_batch + + def action(self, action: torch.Tensor) -> torch.Tensor: + """ + This method is not used since we override __call__. + Required by ActionProcessorStep ABC. + """ + tokens, _ = self._tokenize_action(action) + return tokens + + def get_config(self) -> dict[str, Any]: + """ + Returns the serializable configuration of the processor. + + Note: The tokenizer object itself is not serialized. If the processor was initialized + with a tokenizer name, that name will be included in the config. + + Returns: + A dictionary with the processor's configuration parameters. 
+ """ + config = { + "trust_remote_code": self.trust_remote_code, + "max_action_tokens": self.max_action_tokens, + } + + # Only save tokenizer_name if it was used to create the tokenizer + if self.action_tokenizer_name is not None and self.action_tokenizer_input_object is None: + config["action_tokenizer_name"] = self.action_tokenizer_name + + return config + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates feature definitions to reflect tokenized actions. + + This updates the policy features dictionary to indicate that the action + has been tokenized into a sequence of token IDs with shape (max_action_tokens,). + + Args: + features: The dictionary of existing policy features. + + Returns: + The updated dictionary of policy features. + """ + return features diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index dfa10b2e5..a96e1596d 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -28,6 +28,8 @@ OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" ACTION = "action" +ACTION_TOKENS = ACTION + ".tokens" +ACTION_TOKEN_MASK = ACTION + ".token_mask" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 3a01aee88..e6817ba6c 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -63,6 +63,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b _transformers_available = is_package_available("transformers") _peft_available = is_package_available("peft") +_scipy_available = is_package_available("scipy") def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py 
b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py new file mode 100644 index 000000000..9ebc4ba89 --- /dev/null +++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script to verify PI0Fast policy integration with LeRobot vs the original implementation""" +# ruff: noqa: E402 + +import os +import random +from copy import deepcopy +from typing import Any + +import numpy as np +import pytest +import torch + +pytest.importorskip("transformers") +pytest.importorskip("scipy") +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires accepting the model license", +) + +from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig +from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy +from lerobot.policies.pi0_fast.processor_pi0_fast import make_pi0_fast_pre_post_processors +from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.utils.constants import ( + ACTION_TOKEN_MASK, + ACTION_TOKENS, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) # noqa: E402 +from tests.utils import require_cuda # noqa: E402 + +# Constants +DUMMY_ACTION_DIM = 7 +DUMMY_STATE_DIM = 20 +IMAGE_HEIGHT = 224 +IMAGE_WIDTH = 224 
+NUM_VIEWS = 2 # Number of camera views +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +MODEL_PATH_LEROBOT = "lerobot/pi0fast-base" + +# Expected action token shape: (batch_size, max_decoding_steps) +EXPECTED_ACTION_TOKENS_SHAPE = (1, 2) + +# Expected first 5 action tokens (for reproducibility check) +EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362]) + +# Expected actions after detokenization +EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim) +EXPECTED_ACTIONS_MEAN = 0.04419417306780815 +EXPECTED_ACTIONS_STD = 0.26231569051742554 +EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000]) + + +def set_seed_all(seed: int): + """Set random seed for all RNG sources to ensure reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # Set deterministic behavior + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True, warn_only=True) + + +def instantiate_lerobot_pi0_fast( + from_pretrained: bool = False, + model_path: str = MODEL_PATH_LEROBOT, +) -> tuple[ + Any, # Policy + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Instantiate LeRobot PI0Fast policy with preprocessor and postprocessor.""" + if from_pretrained: + policy = PI0FastPolicy.from_pretrained( + pretrained_name_or_path=model_path, + strict=True, + ) + policy.config.validate_action_token_prefix = False + policy.config.max_action_tokens = 2 + policy.config.max_decoding_steps = 2 + policy.config.chunk_size = 2 + policy.config.n_action_steps = 2 + else: + config = PI0FastConfig( + n_action_steps=2, + max_action_dim=DUMMY_ACTION_DIM, + max_state_dim=DUMMY_STATE_DIM, + device=DEVICE, + validate_action_token_prefix=False, + max_action_tokens=2, + 
max_decoding_steps=2, + chunk_size=2, + ) + policy = PI0FastPolicy(config) + + policy.to(DEVICE) + policy.config.device = DEVICE + preprocessor, postprocessor = make_pi0_fast_pre_post_processors( + config=policy.config, + dataset_stats=None, # Pass None for dataset_stats to disable normalization + ) + + return policy, preprocessor, postprocessor + + +def create_dummy_data(device=DEVICE): + """Create dummy data for testing both implementations.""" + batch_size = 1 + prompt = "Pick up the red block and place it in the bin" + + # Create random RGB images in [0, 255] uint8 range (as PIL images would be) + # Then convert to [0, 1] float32 range for LeRobot + def fake_rgb(h, w): + arr = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) + t = torch.from_numpy(arr).permute(2, 0, 1) # CHW + return t + + batch = { + f"{OBS_IMAGES}.base_0_rgb": torch.stack( + [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)] + ).to(device), + f"{OBS_IMAGES}.left_wrist_0_rgb": torch.stack( + [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)] + ).to(device), + f"{OBS_IMAGES}.right_wrist_0_rgb": torch.stack( + [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)] + ).to(device), + OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), + "task": [prompt for _ in range(batch_size)], + } + + return batch + + +# Pytest fixtures +@pytest.fixture(scope="module") +def pi0_fast_components(): + """Fixture to instantiate and provide all PI0Fast components for tests.""" + print(f"\nTesting with DEVICE='{DEVICE}'") + print("\n[Setup] Instantiating LeRobot PI0Fast policy...") + policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_pi0_fast(from_pretrained=True) + print("Model loaded successfully") + yield policy_obj, preprocessor_obj, postprocessor_obj + + +@pytest.fixture(scope="module") +def policy(pi0_fast_components): + """Fixture to provide the PI0Fast policy for tests.""" + return pi0_fast_components[0] + + 
+@pytest.fixture(scope="module") +def preprocessor(pi0_fast_components): + """Fixture to provide the PI0Fast preprocessor for tests.""" + return pi0_fast_components[1] + + +@require_cuda +def test_pi0_fast_preprocessor_alignment(policy, preprocessor): + """Test that LeRobot PI0Fast preprocessor produces expected outputs.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Preprocessor Outputs") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + print("\n[LeRobot] Preprocessing...") + lerobot_observation = preprocessor(deepcopy(batch)) + + print("\nVerifying preprocessor outputs:") + print("-" * 80) + + # Expected keys from PI0Fast preprocessing + expected_keys = [ + "observation.images.base_0_rgb", + "observation.images.left_wrist_0_rgb", + "observation.images.right_wrist_0_rgb", + "observation.state", + "observation.language_tokens", + "observation.language_attention_mask", + ] + + for key in expected_keys: + if key in lerobot_observation: + shape = tuple(lerobot_observation[key].shape) + print(f"\nKey: {key}") + print(f"Shape: {shape}") + print(f"Dtype: {lerobot_observation[key].dtype}") + else: + print(f"\nKey '{key}' not found in inputs!") + + # Check language tokens shape + if "observation.language_tokens" in lerobot_observation: + lang_tokens = lerobot_observation["observation.language_tokens"] + print(f"\nLanguage tokens shape: {lang_tokens.shape}") + # Should have batch dimension and max_length from tokenizer + assert lang_tokens.dim() == 2, f"Expected 2D tensor, got {lang_tokens.dim()}D" + + print("\nPreprocessor outputs verified!") + + +@require_cuda +def test_pi0_fast_action_generation(policy, preprocessor): + """Test PI0Fast LeRobot implementation generates expected actions.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Action Generation Against Expected Values") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + print("\n[LeRobot] 
Running inference...") + lerobot_observation = preprocessor(deepcopy(batch)) + + # Reset seed for inference + torch.manual_seed(42) + with torch.no_grad(): + lerobot_actions = policy.predict_action_chunk(lerobot_observation) + lerobot_actions = lerobot_actions.float().cpu() + + print(f"LeRobot actions shape: {lerobot_actions.shape}") + print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}") + print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}") + print(f"LeRobot actions first 5: {lerobot_actions[0, 0, :5]}") + + print("\nExpected values (from original PI0Fast):") + print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}") + print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}") + print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}") + print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}") + + print("\nAction Comparison:") + print("-" * 80) + + # Compare shapes + actual_shape = tuple(lerobot_actions.shape) + print(f"Actual shape: {actual_shape}") + + assert actual_shape == EXPECTED_ACTIONS_SHAPE, ( + f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}" + ) + print(f"Shape matches: {actual_shape}") + + # Compare statistics + actual_mean = lerobot_actions.mean().item() + actual_std = lerobot_actions.std().item() + + print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f})") + print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f})") + + # Compare first 5 actions + actual_first_5 = lerobot_actions[0, 0, :5] + print("\nFirst 5 actions comparison:") + print(f" Actual: {actual_first_5}") + print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}") + + first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5) + print(f" Max diff: {first_5_diff.max().item():.6e}") + print(f" Mean diff: {first_5_diff.mean().item():.6e}") + + # Check with different tolerances + tolerances = [1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, 
atol=tol) + status = "Success" if is_close else "Failure" + print(f"{status}: First 5 actions close (atol={tol}): {is_close}") + + # Assert with reasonable tolerance + tolerance = 1e-3 + assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), ( + f"First 5 actions differ by more than tolerance ({tolerance})" + ) + print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!") + + print("\nAction generation test completed (values printed for reference)!") + + +@require_cuda +def test_pi0_fast_inference_reproducibility(policy, preprocessor): + """Test that PI0Fast inference is reproducible with the same seed.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Inference Reproducibility") + print("=" * 80) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + # First inference + print("\n[Run 1] Running inference...") + set_seed_all(42) + lerobot_observation = preprocessor(deepcopy(batch)) + with torch.no_grad(): + actions_1 = policy.predict_action_chunk(lerobot_observation) + actions_1 = actions_1.float().cpu() + + # Second inference with same seed + print("\n[Run 2] Running inference with same seed...") + set_seed_all(42) + lerobot_observation = preprocessor(deepcopy(batch)) + with torch.no_grad(): + actions_2 = policy.predict_action_chunk(lerobot_observation) + actions_2 = actions_2.float().cpu() + + print("\nComparing two runs:") + print("-" * 80) + if torch.allclose(actions_1, actions_2, atol=1e-8): + print("Inference is perfectly reproducible!") + else: + diff = torch.abs(actions_1 - actions_2) + print("Small differences detected:") + print(f" Max diff: {diff.max().item():.6e}") + print(f" Mean diff: {diff.mean().item():.6e}") + + assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!" 
+ + print("\nInference is reproducible!") + + +@require_cuda +def test_pi0_fast_forward_pass_logits(policy, preprocessor): + """Test PI0Fast forward pass and compare logits against expected values.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Forward Pass Logits") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data with action tokens...") + batch = create_dummy_data() + + # Preprocess the batch + lerobot_observation = preprocessor(deepcopy(batch)) + + # For forward pass, we need action tokens + # Create dummy action tokens for testing + batch_size = 1 + max_action_tokens = policy.config.max_action_tokens + + # Create dummy action tokens (in practice, these come from the FAST tokenizer) + dummy_action_tokens = torch.randint( + 0, 1000, (batch_size, max_action_tokens), dtype=torch.long, device=DEVICE + ) + dummy_action_masks = torch.ones(batch_size, max_action_tokens, dtype=torch.bool, device=DEVICE) + + # Add action tokens to the observation + lerobot_observation[ACTION_TOKENS] = dummy_action_tokens + lerobot_observation[ACTION_TOKEN_MASK] = dummy_action_masks + + print("\n[LeRobot] Running forward pass...") + policy.train() + with torch.no_grad(): + loss, loss_dict = policy.forward(lerobot_observation) + + print(f"Loss: {loss.item():.6f}") + print(f"FAST Loss: {loss_dict['ce_loss']:.6f}") + + print("\nForward pass completed successfully!") + print(f"Loss value: {loss.item():.6f}") + + # The loss should be a positive value + assert loss.item() > 0, "Loss should be positive" + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isinf(loss), "Loss should not be infinite" + + print("\nForward pass test passed!") + + +@require_cuda +def test_pi0_fast_action_token_sampling(policy, preprocessor): + """Test PI0Fast action token sampling (autoregressive decoding).""" + print("\n" + "=" * 80) + print("Test: PI0Fast Action Token Sampling") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = 
create_dummy_data() + + print("\n[LeRobot] Preprocessing...") + lerobot_observation = preprocessor(deepcopy(batch)) + + # Prepare inputs for model + images, img_masks = policy._preprocess_images(lerobot_observation) + tokens = lerobot_observation[OBS_LANGUAGE_TOKENS] + masks = lerobot_observation[OBS_LANGUAGE_ATTENTION_MASK] + + print("\n[LeRobot] Sampling action tokens...") + torch.manual_seed(42) + with torch.no_grad(): + action_tokens = policy.model.sample_actions_fast( + images, + img_masks, + tokens, + masks, + max_decoding_steps=2, + temperature=0.0, # Greedy decoding for reproducibility + ) + + print(f"Action tokens shape: {action_tokens.shape}") + print(f"Action tokens first 10: {action_tokens[0, :10].tolist()}") + + print("\nExpected values (from original PI0Fast):") + print(f"Expected shape: {EXPECTED_ACTION_TOKENS_SHAPE}") + print(f"Expected first 5: {EXPECTED_ACTION_TOKENS_FIRST_5.tolist()}") + + # Verify shape + actual_shape = tuple(action_tokens.shape) + print(f"\nActual shape: {actual_shape}") + + assert actual_shape == EXPECTED_ACTION_TOKENS_SHAPE, ( + f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTION_TOKENS_SHAPE}" + ) + + # Compare first 5 tokens + actual_first_5 = action_tokens[0, :5].cpu() + assert torch.equal(actual_first_5, EXPECTED_ACTION_TOKENS_FIRST_5), ( + f"First 5 tokens mismatch: {actual_first_5} vs {EXPECTED_ACTION_TOKENS_FIRST_5}" + ) + + print("\nAction token sampling test completed!") + + +@require_cuda +def test_pi0_fast_detokenization(policy, preprocessor): + """Test PI0Fast action detokenization (FAST decoding).""" + print("\n" + "=" * 80) + print("Test: PI0Fast Action Detokenization") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + print("\n[LeRobot] Preprocessing...") + lerobot_observation = preprocessor(deepcopy(batch)) + + # Prepare inputs for model + images, img_masks = policy._preprocess_images(lerobot_observation) + tokens = 
lerobot_observation[OBS_LANGUAGE_TOKENS] + masks = lerobot_observation[OBS_LANGUAGE_ATTENTION_MASK] + + print("\n[LeRobot] Sampling action tokens...") + torch.manual_seed(42) + with torch.no_grad(): + action_tokens = policy.model.sample_actions_fast( + images, + img_masks, + tokens, + masks, + max_decoding_steps=2, + temperature=0.0, + ) + + print(f"Action tokens shape: {action_tokens.shape}") + + # Detokenize + print("\n[LeRobot] Detokenizing action tokens...") + action_horizon = policy.config.n_action_steps + action_dim = policy.config.output_features["action"].shape[0] + + try: + continuous_actions = policy.detokenize_actions( + action_tokens, action_horizon=action_horizon, action_dim=action_dim + ) + print(f"Continuous actions shape: {continuous_actions.shape}") + print(f"Continuous actions mean: {continuous_actions.mean().item():.6f}") + print(f"Continuous actions std: {continuous_actions.std().item():.6f}") + print(f"Continuous actions first 5: {continuous_actions[0, 0, :5]}") + print("\nDetokenization successful!") + except Exception as e: + print(f"\nDetokenization failed with error: {e}") + print("This may be expected if the action tokens are not valid FAST tokens.") + print("The test will pass as long as the sampling works correctly.") From 91ff9c4975351d39ec3b9659da654bf8bd895cfe Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 12 Jan 2026 12:19:02 +0100 Subject: [PATCH 029/111] Fix: Respect policy.device=cpu config in training (#2778) * fix cpu training in lerobot_train * Update src/lerobot/scripts/lerobot_train.py Signed-off-by: Michel Aractingi --- src/lerobot/scripts/lerobot_train.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 286c69906..927e9659a 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -259,7 +259,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = 
None): from accelerate.utils import DistributedDataParallelKwargs ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) - accelerator = Accelerator(step_scheduler_with_optimizer=False, kwargs_handlers=[ddp_kwargs]) + # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. + # Force the device to be CPU when policy.device is set to CPU. + force_cpu = cfg.policy.device == "cpu" + accelerator = Accelerator( + step_scheduler_with_optimizer=False, + kwargs_handlers=[ddp_kwargs], + cpu=force_cpu, + ) init_logging(accelerator=accelerator) From 473f1bd0e0fff9d5eeb2687eada5473f94cab43a Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Mon, 12 Jan 2026 13:33:28 +0100 Subject: [PATCH 030/111] docs: improve assets (#2777) * add assets * add libero results pifast: * update * update * update size * update naems: : * update training tokenizer --- README.md | 10 +-- docs/source/groot.mdx | 6 ++ docs/source/pi0.mdx | 6 ++ docs/source/pi0fast.mdx | 68 ++++++++++++++++++- docs/source/sarm.mdx | 6 ++ docs/source/walloss.mdx | 6 ++ pyproject.toml | 1 + .../lerobot_train_tokenizer.py} | 33 +++++++++ 8 files changed, 129 insertions(+), 7 deletions(-) rename src/lerobot/{policies/pi0_fast/train_fast_tokenizer.py => scripts/lerobot_train_tokenizer.py} (94%) diff --git a/README.md b/README.md index 35b28da87..57fec2e5f 100644 --- a/README.md +++ b/README.md @@ -100,11 +100,11 @@ lerobot-train \ --dataset.repo_id=lerobot/aloha_mobile_cabinet ``` -| Category | Models | -| -------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | -| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), 
[TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | -| **VLAs Models** | [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | +| Category | Models | +| -------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| **Imitation Learning** | [ACT](./docs/source/policy_act_README.md), [Diffusion](./docs/source/policy_diffusion_README.md), [VQ-BeT](./docs/source/policy_vqbet_README.md) | +| **Reinforcement Learning** | [HIL-SERL](./docs/source/hilserl.mdx), [TDMPC](./docs/source/policy_tdmpc_README.md) & QC-FQL (coming soon) | +| **VLAs Models** | [Pi0Fast](./docs/source/pi0fast.mdx), [Pi0.5](./docs/source/pi05.mdx), [GR00T N1.5](./docs/source/policy_groot_README.md), [SmolVLA](./docs/source/policy_smolvla_README.md), [XVLA](./docs/source/xvla.mdx) | Similarly to the hardware, you can easily implement your own policy & leverage LeRobot's data collection, training, and visualization tools, and share your model to the HF Hub diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx index 729a64656..6c036c9d5 100644 --- a/docs/source/groot.mdx +++ b/docs/source/groot.mdx @@ -12,6 +12,12 @@ Developers and researchers can post-train GR00T N1.5 with their own real or synt GR00T N1.5 (specifically the GR00T-N1.5-3B model) is built using pre-trained vision and language encoders. It utilizes a flow matching action transformer to model a chunk of actions, conditioned on vision, language, and proprioception. +An overview of GR00T + Its strong performance comes from being trained on an expansive and diverse humanoid dataset, which includes: - Real captured data from robots. 
diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx index 89604b6aa..93e0b4c88 100644 --- a/docs/source/pi0.mdx +++ b/docs/source/pi0.mdx @@ -6,6 +6,12 @@ π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by [Physical Intelligence](https://www.physicalintelligence.company/blog/pi0). Unlike traditional robot programs that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks. +An overview of Pi0 + ### The Vision for Physical Intelligence As described by Physical Intelligence, while AI has achieved remarkable success in digital domains, from chess-playing to drug discovery, human intelligence still dramatically outpaces AI in the physical world. To paraphrase Moravec's paradox, winning a game of chess represents an "easy" problem for AI, but folding a shirt or cleaning up a table requires solving some of the most difficult engineering problems ever conceived. π₀ represents a first step toward developing artificial physical intelligence that enables users to simply ask robots to perform any task they want, just like they can with large language models. diff --git a/docs/source/pi0fast.mdx b/docs/source/pi0fast.mdx index e64355765..c4230fa79 100644 --- a/docs/source/pi0fast.mdx +++ b/docs/source/pi0fast.mdx @@ -6,6 +6,12 @@ π₀-FAST combines the power of Vision-Language Models with a novel action tokenization approach called **FAST (Frequency-space Action Sequence Tokenization)**. This enables training autoregressive VLAs on highly dexterous tasks that are impossible with standard binning-based discretization, while training **up to 5x faster** than diffusion-based approaches like π₀. +An overview of Pi0-FAST + ### Why FAST? Standard approaches for robot action tokenization use simple per-dimension, per-timestep binning schemes. 
While passable for simple behaviors, this rapidly breaks down for complex and dexterous skills that require precision and high-frequency control. @@ -53,7 +59,7 @@ You have two options for the FAST tokenizer: ### Training Your Own Tokenizer ```bash -python src/lerobot/policies/pi0_fast/train_fast_tokenizer.py \ +lerobot-train-tokenizer \ --repo_id "user/my-lerobot-dataset" \ --action_horizon 10 \ --encoded_dims "0:6" \ @@ -90,7 +96,7 @@ policy.type=pi0_fast For training π₀-FAST, you can use the LeRobot training script: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=pi0_fast \ --output_dir=./outputs/pi0fast_training \ @@ -171,6 +177,64 @@ The model takes images, text instructions, and robot state as input, and outputs | Inference Method | Iterative Denoising | Autoregressive Decoding | | KV-Caching | N/A | Supported | +## Reproducing π₀Fast results + +We reproduce the results of π₀Fast on the LIBERO benchmark using the LeRobot implementation. We take the LeRobot π₀Fast base model [lerobot/pi0fast-base](https://huggingface.co/lerobot/pi0fast-base) and finetune for an additional 40k steps in bfloat16, with batch size of 256 on 8 H100 GPUs using the [HuggingFace LIBERO dataset](https://huggingface.co/datasets/HuggingFaceVLA/libero).
+ +The finetuned model can be found here: + +- **π₀Fast LIBERO**: [lerobot/pi0fast-libero](https://huggingface.co/lerobot/pi0fast-libero) + +It was trained with the following command: + +```bash +lerobot-train \ + --dataset.repo_id=lerobot/libero \ + --output_dir=outputs/libero_pi0fast \ + --job_name=libero_pi0fast \ + --policy.path=lerobot/pi0fast-base \ + --policy.dtype=bfloat16 \ + --steps=100000 \ + --save_freq=20000 \ + --batch_size=4 \ + --policy.device=cuda \ + --policy.scheduler_warmup_steps=4000 \ + --policy.scheduler_decay_steps=100000 \ + --policy.scheduler_decay_lr=1e-5 \ + --policy.gradient_checkpointing=true \ + --policy.chunk_size=10 \ + --policy.n_action_steps=10 \ + --policy.max_action_tokens=256 \ + --policy.empty_cameras=1 +``` + +We then evaluate the finetuned model using the LeRobot LIBERO implementation, by running the following command: + +```bash +tasks="libero_object,libero_spatial,libero_goal,libero_10" +lerobot-eval \ + --policy.path=lerobot/pi0fast-libero \ + --policy.max_action_tokens=256 \ + --env.type=libero \ + --policy.gradient_checkpointing=false \ + --env.task=${tasks} \ + --eval.batch_size=1 \ + --eval.n_episodes=1 \ + --rename_map='{"observation.images.image":"observation.images.base_0_rgb","observation.images.image2":"observation.images.left_wrist_0_rgb"}' +``` + +**Note:** We set `n_action_steps=10`, similar to the original OpenPI implementation.
+ +### Results + +We obtain the following results on the LIBERO benchmark: + +| Model | LIBERO Spatial | LIBERO Object | LIBERO Goal | LIBERO 10 | Average | +| ----------- | -------------- | ------------- | ----------- | --------- | -------- | +| **π₀-fast** | 70.0 | 100.0 | 100.0 | 60.0 | **82.5** | + +The full evaluation output folder, including videos, is available [here](https://drive.google.com/drive/folders/1HXpwPTRm4hx6g1sF2P7OOqGG0TwPU7LQ?usp=sharing) + ## License This model follows the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi). diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx index 321097692..65e49792b 100644 --- a/docs/source/sarm.mdx +++ b/docs/source/sarm.mdx @@ -4,6 +4,12 @@ SARM (Stage-Aware Reward Modeling) is a video-based reward modeling framework fo **Paper**: [SARM: Stage-Aware Reward Modeling for Long Horizon Robot Manipulation](https://arxiv.org/abs/2509.25358) +An overview of SARM + ## Why Reward Models? Standard behavior cloning treats all demonstration frames equally, but real-world robot datasets are messy. They contain hesitations, corrections, and variable-quality trajectories. Reward models solve this by learning a generalizable notion of **task progress** from demonstrations: given video frames and a task description, they predict how close the robot is to completing the task (0→1). This learned "progress signal" can be used in multiple ways, two promising applications are: (1) **weighted imitation learning** (RA-BC), where high-progress frames receive more weight during policy training, and (2) **reinforcement learning**, where the reward model provides dense rewards for online or offline policy improvement. 
diff --git a/docs/source/walloss.mdx b/docs/source/walloss.mdx index 12e9b1fc7..c0756c087 100644 --- a/docs/source/walloss.mdx +++ b/docs/source/walloss.mdx @@ -8,6 +8,12 @@ X Square Robot’s WALL-OSS is now integrated into Hugging Face’s LeRobot ecos The WALL-OSS team is building the embodied foundation model to capture and compress the world's most valuable data: the continuous, high-fidelity stream of physical interaction. By creating a direct feedback loop between the model's decisions and the body's lived experience, the emergence of a truly generalizable intelligence is enabled—one that understands not just how the world works, but how to act effectively within it. +An overview of WALL-OSS + Technically, WALL-OSS introduces a tightly coupled multimodal architecture (tightly-coupled MoE structure) that integrates both discrete and continuous action modeling strategies. Through a two-stage training pipeline (Inspiration → Integration), the model gradually unifies semantic reasoning and high-frequency action generation. Its core innovations include: - **Embodied perception–enhanced multimodal pretraining**: Large-scale training on unified vision–language–action data to strengthen spatial, causal, and manipulation understanding. 
diff --git a/pyproject.toml b/pyproject.toml index 75738d2de..e8f334c77 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -197,6 +197,7 @@ lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main" lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main" lerobot-eval="lerobot.scripts.lerobot_eval:main" lerobot-train="lerobot.scripts.lerobot_train:main" +lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main" lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main" lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" diff --git a/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py similarity index 94% rename from src/lerobot/policies/pi0_fast/train_fast_tokenizer.py rename to src/lerobot/scripts/lerobot_train_tokenizer.py index 6a3a1fe69..03bfcaaf8 100644 --- a/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -1,3 +1,16 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. """Train FAST tokenizer for action encoding. This script: @@ -6,6 +19,26 @@ This script: 3. Trains FAST tokenizer on specified action dimensions 4. Saves tokenizer to assets directory 5. 
Reports compression statistics + +Example: + +```shell +lerobot-train-tokenizer \ + --repo_id=user/dataset_name \ + --action_horizon=10 \ + --max_episodes=100 \ + --sample_fraction=0.1 \ + --encoded_dims="0:6" \ + --delta_dims="0,1,2,3,4,5" \ + --use_delta_transform=true \ + --state_key="observation.state" \ + --normalization_mode="QUANTILES" \ + --vocab_size=1024 \ + --scale=10.0 \ + --output_dir="./fast_tokenizer_dataset_name" \ + --push_to_hub=true \ + --hub_repo_id="user/fast_tokenizer_dataset_name" \ + --hub_private=false """ import json From d791a431fe87c0786d36802978968b3e22a43cab Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 12 Jan 2026 16:01:22 +0100 Subject: [PATCH 031/111] feat(robots): consolidates bi SO setups (#2780) * feat(robots): consolidates bi SO setups * fix(robots): solve circular dependency * fix(robots): teleop & record working * feat(robots): only one SO * fix(utils): rename bi so * fix(scripts): bi so import * fix(rl): remove imports --- docs/source/envhub_leisaac.mdx | 1 + docs/source/groot.mdx | 2 +- examples/rtc/eval_with_real_robot.py | 1 + examples/so100_to_so100_EE/record.py | 3 +- examples/so100_to_so100_EE/teleoperate.py | 3 +- src/lerobot/async_inference/constants.py | 2 +- src/lerobot/async_inference/robot_client.py | 2 +- .../config_bi_so100_follower.py | 39 -------- .../bi_so_follower}/__init__.py | 6 +- .../bi_so_follower.py} | 91 +++++++++---------- .../bi_so_follower/config_bi_so_follower.py} | 14 ++- src/lerobot/robots/so_follower/__init__.py | 14 +-- ...r_config_base.py => config_so_follower.py} | 14 ++- src/lerobot/robots/so_follower/so100.md | 1 + .../so100_follower/config_so100_follower.py | 26 ------ .../so_follower/so100_follower/so100.md | 1 - .../so100_follower/so100_follower.py | 27 ------ src/lerobot/robots/so_follower/so101.md | 1 + .../so101_follower/config_so101_follower.py | 26 ------ .../so_follower/so101_follower/so101.md | 1 - .../so101_follower/so101_follower.py | 27 ------ .../{so_follower_base.py 
=> so_follower.py} | 15 ++- src/lerobot/robots/utils.py | 6 +- src/lerobot/scripts/lerobot_calibrate.py | 2 + .../scripts/lerobot_find_joint_limits.py | 2 + src/lerobot/scripts/lerobot_record.py | 28 +++--- src/lerobot/scripts/lerobot_replay.py | 4 +- src/lerobot/scripts/lerobot_setup_motors.py | 2 + src/lerobot/scripts/lerobot_teleoperate.py | 24 ++--- .../bi_so_leader}/__init__.py | 5 +- .../bi_so_leader.py} | 51 +++++------ .../config_bi_so_leader.py} | 12 ++- .../teleoperators/so_leader/__init__.py | 13 +-- ...der_config_base.py => config_so_leader.py} | 14 ++- src/lerobot/teleoperators/so_leader/so100.md | 1 + .../so100_leader/config_so100_leader.py | 26 ------ .../so_leader/so100_leader/so100.md | 1 - .../so_leader/so100_leader/so100_leader.py | 27 ------ src/lerobot/teleoperators/so_leader/so101.md | 1 + .../so_leader/so101_leader/so101.md | 1 - .../so_leader/so101_leader/so101_leader.py | 27 ------ .../{so_leader_base.py => so_leader.py} | 14 ++- src/lerobot/teleoperators/utils.py | 6 +- tests/robots/test_so100_follower.py | 2 +- 44 files changed, 199 insertions(+), 387 deletions(-) delete mode 100644 src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py rename src/lerobot/{teleoperators/bi_so100_leader => robots/bi_so_follower}/__init__.py (77%) rename src/lerobot/robots/{bi_so100_follower/bi_so100_follower.py => bi_so_follower/bi_so_follower.py} (53%) rename src/lerobot/{teleoperators/so_leader/so101_leader/config_so101_leader.py => robots/bi_so_follower/config_bi_so_follower.py} (68%) rename src/lerobot/robots/so_follower/{so_follower_config_base.py => config_so_follower.py} (80%) create mode 120000 src/lerobot/robots/so_follower/so100.md delete mode 100644 src/lerobot/robots/so_follower/so100_follower/config_so100_follower.py delete mode 120000 src/lerobot/robots/so_follower/so100_follower/so100.md delete mode 100644 src/lerobot/robots/so_follower/so100_follower/so100_follower.py create mode 120000 src/lerobot/robots/so_follower/so101.md 
delete mode 100644 src/lerobot/robots/so_follower/so101_follower/config_so101_follower.py delete mode 120000 src/lerobot/robots/so_follower/so101_follower/so101.md delete mode 100644 src/lerobot/robots/so_follower/so101_follower/so101_follower.py rename src/lerobot/robots/so_follower/{so_follower_base.py => so_follower.py} (96%) rename src/lerobot/{robots/bi_so100_follower => teleoperators/bi_so_leader}/__init__.py (76%) rename src/lerobot/teleoperators/{bi_so100_leader/bi_so100_leader.py => bi_so_leader/bi_so_leader.py} (62%) rename src/lerobot/teleoperators/{bi_so100_leader/config_bi_so100_leader.py => bi_so_leader/config_bi_so_leader.py} (71%) rename src/lerobot/teleoperators/so_leader/{so_leader_config_base.py => config_so_leader.py} (72%) create mode 120000 src/lerobot/teleoperators/so_leader/so100.md delete mode 100644 src/lerobot/teleoperators/so_leader/so100_leader/config_so100_leader.py delete mode 120000 src/lerobot/teleoperators/so_leader/so100_leader/so100.md delete mode 100644 src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py create mode 120000 src/lerobot/teleoperators/so_leader/so101.md delete mode 120000 src/lerobot/teleoperators/so_leader/so101_leader/so101.md delete mode 100644 src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py rename src/lerobot/teleoperators/so_leader/{so_leader_base.py => so_leader.py} (95%) diff --git a/docs/source/envhub_leisaac.mdx b/docs/source/envhub_leisaac.mdx index 1fc74c0fa..2537700a5 100644 --- a/docs/source/envhub_leisaac.mdx +++ b/docs/source/envhub_leisaac.mdx @@ -138,6 +138,7 @@ from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, make_teleoperator_from_config, so_leader, + bi_so_leader, ) from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import init_logging diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx index 6c036c9d5..8bfc22996 100644 --- a/docs/source/groot.mdx +++ b/docs/source/groot.mdx @@ -109,7 +109,7 @@ Once you have 
trained your model using your parameters you can run inference in ```bash lerobot-record \ - --robot.type=bi_so100_follower \ + --robot.type=bi_so_follower \ --robot.left_arm_port=/dev/ttyACM1 \ --robot.right_arm_port=/dev/ttyACM0 \ --robot.id=bimanual_follower \ diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 991d2468d..1470899d9 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -94,6 +94,7 @@ from lerobot.rl.process import ProcessSignalHandler from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_so_follower, koch_follower, so_follower, ) diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index db24f4b93..eead7a9a8 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -34,8 +34,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( InverseKinematicsEEToJoints, ) from lerobot.scripts.lerobot_record import record_loop -from lerobot.teleoperators.so_leader import SO100LeaderConfig -from lerobot.teleoperators.so_leader.so100_leader import SO100Leader +from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.utils.control_utils import init_keyboard_listener from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun diff --git a/examples/so100_to_so100_EE/teleoperate.py b/examples/so100_to_so100_EE/teleoperate.py index d520a6eaf..71d2899de 100644 --- a/examples/so100_to_so100_EE/teleoperate.py +++ b/examples/so100_to_so100_EE/teleoperate.py @@ -29,8 +29,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import ( ForwardKinematicsJointsToEE, InverseKinematicsEEToJoints, ) -from lerobot.teleoperators.so_leader import SO100LeaderConfig -from lerobot.teleoperators.so_leader.so100_leader import SO100Leader +from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from 
lerobot.utils.robot_utils import precise_sleep from lerobot.utils.visualization_utils import init_rerun, log_rerun_data diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py index f8b6d7bb3..081db0504 100644 --- a/src/lerobot/async_inference/constants.py +++ b/src/lerobot/async_inference/constants.py @@ -26,4 +26,4 @@ DEFAULT_OBS_QUEUE_TIMEOUT = 2 SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] # TODO: Add all other robots -SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so100_follower", "omx_follower"] +SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"] diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index c3668d40b..eea5585b0 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -51,7 +51,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, - bi_so100_follower, + bi_so_follower, koch_follower, make_robot_from_config, omx_follower, diff --git a/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py b/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py deleted file mode 100644 index 5806d7415..000000000 --- a/src/lerobot/robots/bi_so100_follower/config_bi_so100_follower.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass, field - -from lerobot.cameras import CameraConfig - -from ..config import RobotConfig - - -@RobotConfig.register_subclass("bi_so100_follower") -@dataclass -class BiSO100FollowerConfig(RobotConfig): - left_arm_port: str - right_arm_port: str - - # Optional - left_arm_disable_torque_on_disconnect: bool = True - left_arm_max_relative_target: float | dict[str, float] | None = None - left_arm_use_degrees: bool = False - right_arm_disable_torque_on_disconnect: bool = True - right_arm_max_relative_target: float | dict[str, float] | None = None - right_arm_use_degrees: bool = False - - # cameras (shared between both arms) - cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/teleoperators/bi_so100_leader/__init__.py b/src/lerobot/robots/bi_so_follower/__init__.py similarity index 77% rename from src/lerobot/teleoperators/bi_so100_leader/__init__.py rename to src/lerobot/robots/bi_so_follower/__init__.py index 34313a61e..f631a14db 100644 --- a/src/lerobot/teleoperators/bi_so100_leader/__init__.py +++ b/src/lerobot/robots/bi_so_follower/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,5 +14,5 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .bi_so100_leader import BiSO100Leader -from .config_bi_so100_leader import BiSO100LeaderConfig +from .bi_so_follower import BiSOFollower +from .config_bi_so_follower import BiSOFollowerConfig diff --git a/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py similarity index 53% rename from src/lerobot/robots/bi_so100_follower/bi_so100_follower.py rename to src/lerobot/robots/bi_so_follower/bi_so_follower.py index 87a7edcc5..fa81e7d09 100644 --- a/src/lerobot/robots/bi_so100_follower/bi_so100_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -15,66 +15,73 @@ # limitations under the License. import logging -import time from functools import cached_property from typing import Any -from lerobot.cameras.utils import make_cameras_from_configs -from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig from ..robot import Robot -from .config_bi_so100_follower import BiSO100FollowerConfig +from .config_bi_so_follower import BiSOFollowerConfig logger = logging.getLogger(__name__) -class BiSO100Follower(Robot): +class BiSOFollower(Robot): """ - [Bimanual SO-100 Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio - This bimanual robot can also be easily adapted to use SO-101 follower arms, just replace the SO100Follower class with SO101Follower and SO100FollowerConfig with SO101FollowerConfig. 
+ [Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio """ - config_class = BiSO100FollowerConfig - name = "bi_so100_follower" + config_class = BiSOFollowerConfig + name = "bi_so_follower" - def __init__(self, config: BiSO100FollowerConfig): + def __init__(self, config: BiSOFollowerConfig): super().__init__(config) self.config = config - left_arm_config = SO100FollowerConfig( + left_arm_config = SOFollowerRobotConfig( id=f"{config.id}_left" if config.id else None, calibration_dir=config.calibration_dir, - port=config.left_arm_port, - disable_torque_on_disconnect=config.left_arm_disable_torque_on_disconnect, - max_relative_target=config.left_arm_max_relative_target, - use_degrees=config.left_arm_use_degrees, - cameras={}, + port=config.left_arm_config.port, + disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, + max_relative_target=config.left_arm_config.max_relative_target, + use_degrees=config.left_arm_config.use_degrees, + cameras=config.left_arm_config.cameras, ) - right_arm_config = SO100FollowerConfig( + right_arm_config = SOFollowerRobotConfig( id=f"{config.id}_right" if config.id else None, calibration_dir=config.calibration_dir, - port=config.right_arm_port, - disable_torque_on_disconnect=config.right_arm_disable_torque_on_disconnect, - max_relative_target=config.right_arm_max_relative_target, - use_degrees=config.right_arm_use_degrees, - cameras={}, + port=config.right_arm_config.port, + disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect, + max_relative_target=config.right_arm_config.max_relative_target, + use_degrees=config.right_arm_config.use_degrees, + cameras=config.right_arm_config.cameras, ) - self.left_arm = SO100Follower(left_arm_config) - self.right_arm = SO100Follower(right_arm_config) - self.cameras = make_cameras_from_configs(config.cameras) + self.left_arm = SOFollower(left_arm_config) + self.right_arm = SOFollower(right_arm_config) + + 
# Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute + self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras} @property def _motors_ft(self) -> dict[str, type]: - return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | { - f"right_{motor}.pos": float for motor in self.right_arm.bus.motors + left_arm_motors_ft = self.left_arm._motors_ft + right_arm_motors_ft = self.right_arm._motors_ft + + return { + **{f"left_{k}": v for k, v in left_arm_motors_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_motors_ft.items()}, } @property def _cameras_ft(self) -> dict[str, tuple]: + left_arm_cameras_ft = self.left_arm._cameras_ft + right_arm_cameras_ft = self.right_arm._cameras_ft + return { - cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + **{f"left_{k}": v for k, v in left_arm_cameras_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_cameras_ft.items()}, } @cached_property @@ -87,19 +94,12 @@ class BiSO100Follower(Robot): @property def is_connected(self) -> bool: - return ( - self.left_arm.bus.is_connected - and self.right_arm.bus.is_connected - and all(cam.is_connected for cam in self.cameras.values()) - ) + return self.left_arm.is_connected and self.right_arm.is_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) - for cam in self.cameras.values(): - cam.connect() - @property def is_calibrated(self) -> bool: return self.left_arm.is_calibrated and self.right_arm.is_calibrated @@ -127,12 +127,6 @@ class BiSO100Follower(Robot): right_obs = self.right_arm.get_observation() obs_dict.update({f"right_{key}": value for key, value in right_obs.items()}) - for cam_key, cam in self.cameras.items(): - start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") - 
return obs_dict def send_action(self, action: dict[str, Any]) -> dict[str, Any]: @@ -145,18 +139,15 @@ class BiSO100Follower(Robot): key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_") } - send_action_left = self.left_arm.send_action(left_action) - send_action_right = self.right_arm.send_action(right_action) + sent_action_left = self.left_arm.send_action(left_action) + sent_action_right = self.right_arm.send_action(right_action) # Add prefixes back - prefixed_send_action_left = {f"left_{key}": value for key, value in send_action_left.items()} - prefixed_send_action_right = {f"right_{key}": value for key, value in send_action_right.items()} + prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()} + prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} - return {**prefixed_send_action_left, **prefixed_send_action_right} + return {**prefixed_sent_action_left, **prefixed_sent_action_right} def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() - - for cam in self.cameras.values(): - cam.disconnect() diff --git a/src/lerobot/teleoperators/so_leader/so101_leader/config_so101_leader.py b/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py similarity index 68% rename from src/lerobot/teleoperators/so_leader/so101_leader/config_so101_leader.py rename to src/lerobot/robots/bi_so_follower/config_bi_so_follower.py index 1a6bb0df4..dca74fa2d 100644 --- a/src/lerobot/teleoperators/so_leader/so101_leader/config_so101_leader.py +++ b/src/lerobot/robots/bi_so_follower/config_bi_so_follower.py @@ -16,11 +16,15 @@ from dataclasses import dataclass -from ...config import TeleoperatorConfig -from ..so_leader_config_base import SOLeaderConfigBase +from lerobot.robots.so_follower import SOFollowerConfig + +from ..config import RobotConfig -@TeleoperatorConfig.register_subclass("so101_leader") 
+@RobotConfig.register_subclass("bi_so_follower") @dataclass -class SO101LeaderConfig(SOLeaderConfigBase): - pass +class BiSOFollowerConfig(RobotConfig): + """Configuration class for Bi SO Follower robots.""" + + left_arm_config: SOFollowerConfig + right_arm_config: SOFollowerConfig diff --git a/src/lerobot/robots/so_follower/__init__.py b/src/lerobot/robots/so_follower/__init__.py index 82755250c..eea2fcbdf 100644 --- a/src/lerobot/robots/so_follower/__init__.py +++ b/src/lerobot/robots/so_follower/__init__.py @@ -14,10 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from .so100_follower.config_so100_follower import SO100FollowerConfig -from .so100_follower.so100_follower import SO100Follower -from .so101_follower.config_so101_follower import SO101FollowerConfig -from .so101_follower.so101_follower import SO101Follower -from .so_follower_base import SOFollowerBase -from .so_follower_config_base import SOFollowerConfigBase +from .config_so_follower import ( + SO100FollowerConfig, + SO101FollowerConfig, + SOFollowerConfig, + SOFollowerRobotConfig, +) +from .so_follower import SO100Follower, SO101Follower, SOFollower diff --git a/src/lerobot/robots/so_follower/so_follower_config_base.py b/src/lerobot/robots/so_follower/config_so_follower.py similarity index 80% rename from src/lerobot/robots/so_follower/so_follower_config_base.py rename to src/lerobot/robots/so_follower/config_so_follower.py index 6e42df278..e9ce27123 100644 --- a/src/lerobot/robots/so_follower/so_follower_config_base.py +++ b/src/lerobot/robots/so_follower/config_so_follower.py @@ -15,6 +15,7 @@ # limitations under the License. 
from dataclasses import dataclass, field +from typing import TypeAlias from lerobot.cameras import CameraConfig @@ -22,7 +23,7 @@ from ..config import RobotConfig @dataclass -class SOFollowerConfigBase(RobotConfig): +class SOFollowerConfig: """Base configuration class for SO Follower robots.""" # Port to connect to the arm @@ -40,3 +41,14 @@ class SOFollowerConfigBase(RobotConfig): # Set to `True` for backward compatibility with previous policies/dataset use_degrees: bool = False + + +@RobotConfig.register_subclass("so101_follower") +@RobotConfig.register_subclass("so100_follower") +@dataclass +class SOFollowerRobotConfig(RobotConfig, SOFollowerConfig): + pass + + +SO100FollowerConfig: TypeAlias = SOFollowerRobotConfig +SO101FollowerConfig: TypeAlias = SOFollowerRobotConfig diff --git a/src/lerobot/robots/so_follower/so100.md b/src/lerobot/robots/so_follower/so100.md new file mode 120000 index 000000000..ad1154e75 --- /dev/null +++ b/src/lerobot/robots/so_follower/so100.md @@ -0,0 +1 @@ +../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so_follower/so100_follower/config_so100_follower.py b/src/lerobot/robots/so_follower/so100_follower/config_so100_follower.py deleted file mode 100644 index 3e8dd7e9f..000000000 --- a/src/lerobot/robots/so_follower/so100_follower/config_so100_follower.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass - -from ...config import RobotConfig -from ..so_follower_config_base import SOFollowerConfigBase - - -@RobotConfig.register_subclass("so100_follower") -@dataclass -class SO100FollowerConfig(SOFollowerConfigBase): - pass diff --git a/src/lerobot/robots/so_follower/so100_follower/so100.md b/src/lerobot/robots/so_follower/so100_follower/so100.md deleted file mode 120000 index f06f88ff6..000000000 --- a/src/lerobot/robots/so_follower/so100_follower/so100.md +++ /dev/null @@ -1 +0,0 @@ -../../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so_follower/so100_follower/so100_follower.py b/src/lerobot/robots/so_follower/so100_follower/so100_follower.py deleted file mode 100644 index ef61f5ce3..000000000 --- a/src/lerobot/robots/so_follower/so100_follower/so100_follower.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..so_follower_base import SOFollowerBase -from .config_so100_follower import SO100FollowerConfig - - -class SO100Follower(SOFollowerBase): - """ - SO-101 follower robot class. 
[SO-101 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio - """ - - config_class = SO100FollowerConfig - name = "so100_follower" diff --git a/src/lerobot/robots/so_follower/so101.md b/src/lerobot/robots/so_follower/so101.md new file mode 120000 index 000000000..27b892660 --- /dev/null +++ b/src/lerobot/robots/so_follower/so101.md @@ -0,0 +1 @@ +../../../../docs/source/so101.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so_follower/so101_follower/config_so101_follower.py b/src/lerobot/robots/so_follower/so101_follower/config_so101_follower.py deleted file mode 100644 index 950e8c839..000000000 --- a/src/lerobot/robots/so_follower/so101_follower/config_so101_follower.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dataclasses import dataclass - -from ...config import RobotConfig -from ..so_follower_config_base import SOFollowerConfigBase - - -@RobotConfig.register_subclass("so101_follower") -@dataclass -class SO101FollowerConfig(SOFollowerConfigBase): - pass diff --git a/src/lerobot/robots/so_follower/so101_follower/so101.md b/src/lerobot/robots/so_follower/so101_follower/so101.md deleted file mode 120000 index 38f4deca7..000000000 --- a/src/lerobot/robots/so_follower/so101_follower/so101.md +++ /dev/null @@ -1 +0,0 @@ -../../../../../docs/source/so101.mdx \ No newline at end of file diff --git a/src/lerobot/robots/so_follower/so101_follower/so101_follower.py b/src/lerobot/robots/so_follower/so101_follower/so101_follower.py deleted file mode 100644 index b4cfb2711..000000000 --- a/src/lerobot/robots/so_follower/so101_follower/so101_follower.py +++ /dev/null @@ -1,27 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..so_follower_base import SOFollowerBase -from .config_so101_follower import SO101FollowerConfig - - -class SO101Follower(SOFollowerBase): - """ - SO-101 follower robot class. 
[SO-101 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio - """ - - config_class = SO101FollowerConfig - name = "so101_follower" diff --git a/src/lerobot/robots/so_follower/so_follower_base.py b/src/lerobot/robots/so_follower/so_follower.py similarity index 96% rename from src/lerobot/robots/so_follower/so_follower_base.py rename to src/lerobot/robots/so_follower/so_follower.py index 5e6e32888..5e99b33a1 100644 --- a/src/lerobot/robots/so_follower/so_follower_base.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -17,7 +17,7 @@ import logging import time from functools import cached_property -from typing import Any +from typing import Any, TypeAlias from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -29,20 +29,21 @@ from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnected from ..robot import Robot from ..utils import ensure_safe_goal_position -from .so_follower_config_base import SOFollowerConfigBase +from .config_so_follower import SOFollowerRobotConfig logger = logging.getLogger(__name__) -class SOFollowerBase(Robot): +class SOFollower(Robot): """ Generic SO follower base implementing common functionality for SO-100/101/10X. Designed to be subclassed with a per-hardware-model `config_class` and `name`. 
""" - # `config_class` and `name` should be set by subclasses + config_class = SOFollowerRobotConfig + name = "so_follower" - def __init__(self, config: SOFollowerConfigBase): + def __init__(self, config: SOFollowerRobotConfig): super().__init__(config) self.config = config # choose normalization mode depending on config if available @@ -232,3 +233,7 @@ class SOFollowerBase(Robot): cam.disconnect() logger.info(f"{self} disconnected.") + + +SO100Follower: TypeAlias = SOFollower +SO101Follower: TypeAlias = SOFollower diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index ad6cc3da1..27abaaa86 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -52,10 +52,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .hope_jr import HopeJrArm return HopeJrArm(config) - elif config.type == "bi_so100_follower": - from .bi_so100_follower import BiSO100Follower + elif config.type == "bi_so_follower": + from .bi_so_follower import BiSOFollower - return BiSO100Follower(config) + return BiSOFollower(config) elif config.type == "reachy2": from .reachy2 import Reachy2Robot diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 1468c6e93..cbc7684d3 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_so_follower, hope_jr, koch_follower, lekiwi, @@ -46,6 +47,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_so_leader, homunculus, koch_leader, make_teleoperator_from_config, diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index b36cdbc90..20bbc8615 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ 
b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -44,6 +44,7 @@ import numpy as np from lerobot.model.kinematics import RobotKinematics from lerobot.robots import ( # noqa: F401 RobotConfig, + bi_so_follower, koch_follower, make_robot_from_config, omx_follower, @@ -51,6 +52,7 @@ from lerobot.robots import ( # noqa: F401 ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, + bi_so_leader, gamepad, koch_leader, make_teleoperator_from_config, diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index a81a5d54e..8eafa8e6d 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -40,21 +40,23 @@ lerobot-record \ Example recording with bimanual so100: ```shell lerobot-record \ - --robot.type=bi_so100_follower \ - --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ - --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ + --robot.type=bi_so_follower \ + --robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \ + --robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \ --robot.id=bimanual_follower \ - --robot.cameras='{ - left: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}, - top: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30}, - right: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30} + --robot.left_arm_config.cameras='{ + wrist: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30}, + top: {"type": "opencv", "index_or_path": 3, "width": 640, "height": 480, "fps": 30}, + }' --robot.right_arm_config.cameras='{ + wrist: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30}, + front: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30}, }' \ - --teleop.type=bi_so100_leader \ - --teleop.left_arm_port=/dev/tty.usbmodem5A460828611 \ - --teleop.right_arm_port=/dev/tty.usbmodem5A460826981 \ + 
--teleop.type=bi_so_leader \ + --teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \ + --teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \ --teleop.id=bimanual_leader \ --display_data=true \ - --dataset.repo_id=${HF_USER}/bimanual-so100-handover-cube \ + --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \ --dataset.num_episodes=25 \ --dataset.single_task="Grab and handover the red cube to the other arm" ``` @@ -94,7 +96,7 @@ from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, - bi_so100_follower, + bi_so_follower, earthrover_mini_plus, hope_jr, koch_follower, @@ -105,7 +107,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, - bi_so100_leader, + bi_so_leader, homunculus, koch_leader, make_teleoperator_from_config, diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index af7c63365..8e0d9cf6d 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -29,7 +29,7 @@ lerobot-replay \ Example replay with bimanual so100: ```shell lerobot-replay \ - --robot.type=bi_so100_follower \ + --robot.type=bi_so_follower \ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ --robot.id=bimanual_follower \ @@ -53,7 +53,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, - bi_so100_follower, + bi_so_follower, earthrover_mini_plus, hope_jr, koch_follower, diff --git a/src/lerobot/scripts/lerobot_setup_motors.py b/src/lerobot/scripts/lerobot_setup_motors.py index ea5c821d1..01af95b61 100644 --- a/src/lerobot/scripts/lerobot_setup_motors.py +++ b/src/lerobot/scripts/lerobot_setup_motors.py @@ -30,6 +30,7 @@ import draccus from lerobot.robots import ( # noqa: F401 RobotConfig, + bi_so_follower, koch_follower, lekiwi, make_robot_from_config, @@ 
-38,6 +39,7 @@ from lerobot.robots import ( # noqa: F401 ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, + bi_so_leader, koch_leader, make_teleoperator_from_config, omx_leader, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 2e0724574..f2392bb51 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -33,18 +33,18 @@ Example teleoperation with bimanual so100: ```shell lerobot-teleoperate \ - --robot.type=bi_so100_follower \ - --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ - --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ + --robot.type=bi_so_follower \ + --robot.left_arm_config.port=/dev/tty.usbmodem5A460822851 \ + --robot.right_arm_config.port=/dev/tty.usbmodem5A460814411 \ --robot.id=bimanual_follower \ - --robot.cameras='{ - left: {"type": "opencv", "index_or_path": 0, "width": 1920, "height": 1080, "fps": 30}, - top: {"type": "opencv", "index_or_path": 1, "width": 1920, "height": 1080, "fps": 30}, - right: {"type": "opencv", "index_or_path": 2, "width": 1920, "height": 1080, "fps": 30} + --robot.left_arm_config.cameras='{ + wrist: {"type": "opencv", "index_or_path": 1, "width": 640, "height": 480, "fps": 30}, + }' --robot.right_arm_config.cameras='{ + wrist: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30}, }' \ - --teleop.type=bi_so100_leader \ - --teleop.left_arm_port=/dev/tty.usbmodem5A460828611 \ - --teleop.right_arm_port=/dev/tty.usbmodem5A460826981 \ + --teleop.type=bi_so_leader \ + --teleop.left_arm_config.port=/dev/tty.usbmodem5A460852721 \ + --teleop.right_arm_config.port=/dev/tty.usbmodem5A460819811 \ --teleop.id=bimanual_leader \ --display_data=true ``` @@ -70,7 +70,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, - bi_so100_follower, + bi_so_follower, earthrover_mini_plus, hope_jr, koch_follower, @@ -81,7 +81,7 @@ from 
lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, - bi_so100_leader, + bi_so_leader, gamepad, homunculus, keyboard, diff --git a/src/lerobot/robots/bi_so100_follower/__init__.py b/src/lerobot/teleoperators/bi_so_leader/__init__.py similarity index 76% rename from src/lerobot/robots/bi_so100_follower/__init__.py rename to src/lerobot/teleoperators/bi_so_leader/__init__.py index 90f56516b..b902270f9 100644 --- a/src/lerobot/robots/bi_so100_follower/__init__.py +++ b/src/lerobot/teleoperators/bi_so_leader/__init__.py @@ -1,6 +1,6 @@ #!/usr/bin/env python -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,5 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .bi_so100_follower import BiSO100Follower -from .config_bi_so100_follower import BiSO100FollowerConfig +from .bi_so_leader import BiSOLeader, BiSOLeaderConfig diff --git a/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py similarity index 62% rename from src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py rename to src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 93f66eb2e..45c46c100 100644 --- a/src/lerobot/teleoperators/bi_so100_leader/bi_so100_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -17,46 +17,50 @@ import logging from functools import cached_property -from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig +from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig +from ..so_leader import SOLeader from ..teleoperator import Teleoperator -from .config_bi_so100_leader import BiSO100LeaderConfig +from .config_bi_so_leader import BiSOLeaderConfig logger = logging.getLogger(__name__) -class BiSO100Leader(Teleoperator): +class BiSOLeader(Teleoperator): """ - [Bimanual SO-100 Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio - This bimanual leader arm can also be easily adapted to use SO-101 leader arms, just replace the SO100Leader class with SO101Leader and SO100LeaderConfig with SO101LeaderConfig. 
+ [Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio """ - config_class = BiSO100LeaderConfig - name = "bi_so100_leader" + config_class = BiSOLeaderConfig + name = "bi_so_leader" - def __init__(self, config: BiSO100LeaderConfig): + def __init__(self, config: BiSOLeaderConfig): super().__init__(config) self.config = config - left_arm_config = SO100LeaderConfig( + left_arm_config = SOLeaderTeleopConfig( id=f"{config.id}_left" if config.id else None, calibration_dir=config.calibration_dir, - port=config.left_arm_port, + port=config.left_arm_config.port, ) - right_arm_config = SO100LeaderConfig( + right_arm_config = SOLeaderTeleopConfig( id=f"{config.id}_right" if config.id else None, calibration_dir=config.calibration_dir, - port=config.right_arm_port, + port=config.right_arm_config.port, ) - self.left_arm = SO100Leader(left_arm_config) - self.right_arm = SO100Leader(right_arm_config) + self.left_arm = SOLeader(left_arm_config) + self.right_arm = SOLeader(right_arm_config) @cached_property def action_features(self) -> dict[str, type]: - return {f"left_{motor}.pos": float for motor in self.left_arm.bus.motors} | { - f"right_{motor}.pos": float for motor in self.right_arm.bus.motors + left_arm_features = self.left_arm.action_features + right_arm_features = self.right_arm.action_features + + return { + **{f"left_{k}": v for k, v in left_arm_features.items()}, + **{f"right_{k}": v for k, v in right_arm_features.items()}, } @cached_property @@ -101,19 +105,8 @@ class BiSO100Leader(Teleoperator): return action_dict def send_feedback(self, feedback: dict[str, float]) -> None: - # Remove "left_" prefix - left_feedback = { - key.removeprefix("left_"): value for key, value in feedback.items() if key.startswith("left_") - } - # Remove "right_" prefix - right_feedback = { - key.removeprefix("right_"): value for key, value in feedback.items() if key.startswith("right_") - } - - if left_feedback: - 
self.left_arm.send_feedback(left_feedback) - if right_feedback: - self.right_arm.send_feedback(right_feedback) + # TODO: Implement force feedback + raise NotImplementedError def disconnect(self) -> None: self.left_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py b/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py similarity index 71% rename from src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py rename to src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py index 117e09913..c2f23c617 100644 --- a/src/lerobot/teleoperators/bi_so100_leader/config_bi_so100_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/config_bi_so_leader.py @@ -16,11 +16,15 @@ from dataclasses import dataclass +from lerobot.teleoperators.so_leader import SOLeaderConfig + from ..config import TeleoperatorConfig -@TeleoperatorConfig.register_subclass("bi_so100_leader") +@TeleoperatorConfig.register_subclass("bi_so_leader") @dataclass -class BiSO100LeaderConfig(TeleoperatorConfig): - left_arm_port: str - right_arm_port: str +class BiSOLeaderConfig(TeleoperatorConfig): + """Configuration class for Bi SO Leader teleoperators.""" + + left_arm_config: SOLeaderConfig + right_arm_config: SOLeaderConfig diff --git a/src/lerobot/teleoperators/so_leader/__init__.py b/src/lerobot/teleoperators/so_leader/__init__.py index a1017f3b9..e5aaa31b6 100644 --- a/src/lerobot/teleoperators/so_leader/__init__.py +++ b/src/lerobot/teleoperators/so_leader/__init__.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .so100_leader.config_so100_leader import SO100LeaderConfig -from .so100_leader.so100_leader import SO100Leader -from .so101_leader.config_so101_leader import SO101LeaderConfig -from .so101_leader.so101_leader import SO101Leader -from .so_leader_base import SOLeaderBase -from .so_leader_config_base import SOLeaderConfigBase +from .config_so_leader import ( + SO100LeaderConfig, + SO101LeaderConfig, + SOLeaderConfig, + SOLeaderTeleopConfig, +) +from .so_leader import SO100Leader, SO101Leader, SOLeader diff --git a/src/lerobot/teleoperators/so_leader/so_leader_config_base.py b/src/lerobot/teleoperators/so_leader/config_so_leader.py similarity index 72% rename from src/lerobot/teleoperators/so_leader/so_leader_config_base.py rename to src/lerobot/teleoperators/so_leader/config_so_leader.py index b2f2dfcfa..dd55196d7 100644 --- a/src/lerobot/teleoperators/so_leader/so_leader_config_base.py +++ b/src/lerobot/teleoperators/so_leader/config_so_leader.py @@ -15,12 +15,13 @@ # limitations under the License. 
from dataclasses import dataclass +from typing import TypeAlias from ..config import TeleoperatorConfig @dataclass -class SOLeaderConfigBase(TeleoperatorConfig): +class SOLeaderConfig: """Base configuration class for SO Leader teleoperators.""" # Port to connect to the arm @@ -28,3 +29,14 @@ class SOLeaderConfigBase(TeleoperatorConfig): # Whether to use degrees for angles use_degrees: bool = False + + +@TeleoperatorConfig.register_subclass("so101_leader") +@TeleoperatorConfig.register_subclass("so100_leader") +@dataclass +class SOLeaderTeleopConfig(TeleoperatorConfig, SOLeaderConfig): + pass + + +SO100LeaderConfig: TypeAlias = SOLeaderTeleopConfig +SO101LeaderConfig: TypeAlias = SOLeaderTeleopConfig diff --git a/src/lerobot/teleoperators/so_leader/so100.md b/src/lerobot/teleoperators/so_leader/so100.md new file mode 120000 index 000000000..ad1154e75 --- /dev/null +++ b/src/lerobot/teleoperators/so_leader/so100.md @@ -0,0 +1 @@ +../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/teleoperators/so_leader/so100_leader/config_so100_leader.py b/src/lerobot/teleoperators/so_leader/so100_leader/config_so100_leader.py deleted file mode 100644 index 092eb0cfc..000000000 --- a/src/lerobot/teleoperators/so_leader/so100_leader/config_so100_leader.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dataclasses import dataclass - -from ...config import TeleoperatorConfig -from ..so_leader_config_base import SOLeaderConfigBase - - -@TeleoperatorConfig.register_subclass("so100_leader") -@dataclass -class SO100LeaderConfig(SOLeaderConfigBase): - pass diff --git a/src/lerobot/teleoperators/so_leader/so100_leader/so100.md b/src/lerobot/teleoperators/so_leader/so100_leader/so100.md deleted file mode 120000 index f06f88ff6..000000000 --- a/src/lerobot/teleoperators/so_leader/so100_leader/so100.md +++ /dev/null @@ -1 +0,0 @@ -../../../../../docs/source/so100.mdx \ No newline at end of file diff --git a/src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py b/src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py deleted file mode 100644 index 530e4f723..000000000 --- a/src/lerobot/teleoperators/so_leader/so100_leader/so100_leader.py +++ /dev/null @@ -1,27 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..so_leader_base import SOLeaderBase -from .config_so100_leader import SO100LeaderConfig - - -class SO100Leader(SOLeaderBase): - """ - SO-101 leader robot class. 
[SO-101 Leader Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio - """ - - config_class = SO100LeaderConfig - name = "so100_leader" diff --git a/src/lerobot/teleoperators/so_leader/so101.md b/src/lerobot/teleoperators/so_leader/so101.md new file mode 120000 index 000000000..27b892660 --- /dev/null +++ b/src/lerobot/teleoperators/so_leader/so101.md @@ -0,0 +1 @@ +../../../../docs/source/so101.mdx \ No newline at end of file diff --git a/src/lerobot/teleoperators/so_leader/so101_leader/so101.md b/src/lerobot/teleoperators/so_leader/so101_leader/so101.md deleted file mode 120000 index 38f4deca7..000000000 --- a/src/lerobot/teleoperators/so_leader/so101_leader/so101.md +++ /dev/null @@ -1 +0,0 @@ -../../../../../docs/source/so101.mdx \ No newline at end of file diff --git a/src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py b/src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py deleted file mode 100644 index 3aed1f6f9..000000000 --- a/src/lerobot/teleoperators/so_leader/so101_leader/so101_leader.py +++ /dev/null @@ -1,27 +0,0 @@ -# !/usr/bin/env python - -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..so_leader_base import SOLeaderBase -from .config_so101_leader import SO101LeaderConfig - - -class SO101Leader(SOLeaderBase): - """ - SO-101 leader robot class. 
[SO-101 Leader Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio - """ - - config_class = SO101LeaderConfig - name = "so101_leader" diff --git a/src/lerobot/teleoperators/so_leader/so_leader_base.py b/src/lerobot/teleoperators/so_leader/so_leader.py similarity index 95% rename from src/lerobot/teleoperators/so_leader/so_leader_base.py rename to src/lerobot/teleoperators/so_leader/so_leader.py index 6cdaca2e7..760ef2eb1 100644 --- a/src/lerobot/teleoperators/so_leader/so_leader_base.py +++ b/src/lerobot/teleoperators/so_leader/so_leader.py @@ -16,6 +16,7 @@ import logging import time +from typing import TypeAlias from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.feetech import ( @@ -25,15 +26,18 @@ from lerobot.motors.feetech import ( from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator -from .so_leader_config_base import SOLeaderConfigBase +from .config_so_leader import SOLeaderTeleopConfig logger = logging.getLogger(__name__) -class SOLeaderBase(Teleoperator): +class SOLeader(Teleoperator): """Generic SO leader base for SO-100/101/10X teleoperators.""" - def __init__(self, config: SOLeaderConfigBase): + config_class = SOLeaderTeleopConfig + name = "so_leader" + + def __init__(self, config: SOLeaderTeleopConfig): super().__init__(config) self.config = config norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 @@ -153,3 +157,7 @@ class SOLeaderBase(Teleoperator): self.bus.disconnect() logger.info(f"{self} disconnected.") + + +SO100Leader: TypeAlias = SOLeader +SO101Leader: TypeAlias = SOLeader diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 74e43ec95..eec2f119c 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -73,10 +73,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: 
from .homunculus import HomunculusArm return HomunculusArm(config) - elif config.type == "bi_so100_leader": - from .bi_so100_leader import BiSO100Leader + elif config.type == "bi_so_leader": + from .bi_so_leader import BiSOLeader - return BiSO100Leader(config) + return BiSOLeader(config) elif config.type == "reachy2_teleoperator": from .reachy2_teleoperator import Reachy2Teleoperator diff --git a/tests/robots/test_so100_follower.py b/tests/robots/test_so100_follower.py index fc300b820..b61d0ca01 100644 --- a/tests/robots/test_so100_follower.py +++ b/tests/robots/test_so100_follower.py @@ -66,7 +66,7 @@ def follower(): with ( patch( - "lerobot.robots.so_follower.so_follower_base.FeetechMotorsBus", + "lerobot.robots.so_follower.so_follower.FeetechMotorsBus", side_effect=_bus_side_effect, ), patch.object(SO100Follower, "configure", lambda self: None), From 6b8d4c75a6e692a50991fd497a3b021f75499830 Mon Sep 17 00:00:00 2001 From: Martino Russi <77496684+nepyope@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:31:39 +0100 Subject: [PATCH 032/111] Feat/g1 improvements record sim (#2765) This PR extends the integration of Unitree g1 with the LeRobot codebase. By converting robot state to a flat dict we can now record and replay episodes (example groot/holosoma scripts need to be adjusted as well). We also improve the simulation integration by calling .step @ _subscribe_motor_state instead of it running in a separate thread. 
We also add ZMQ camera to lerobot, streaming base64 images over json --- docs/source/unitree_g1.mdx | 6 +- examples/unitree_g1/gr00t_locomotion.py | 47 ++- examples/unitree_g1/holosoma_locomotion.py | 28 +- src/lerobot/cameras/utils.py | 5 + src/lerobot/cameras/zmq/__init__.py | 20 ++ src/lerobot/cameras/zmq/camera_zmq.py | 235 +++++++++++++ src/lerobot/cameras/zmq/configuration_zmq.py | 46 +++ src/lerobot/cameras/zmq/image_server.py | 114 +++++++ .../robots/unitree_g1/config_unitree_g1.py | 5 + src/lerobot/robots/unitree_g1/unitree_g1.py | 314 ++++++++++++------ src/lerobot/scripts/lerobot_record.py | 7 + src/lerobot/scripts/lerobot_replay.py | 1 + 12 files changed, 679 insertions(+), 149 deletions(-) create mode 100644 src/lerobot/cameras/zmq/__init__.py create mode 100644 src/lerobot/cameras/zmq/camera_zmq.py create mode 100644 src/lerobot/cameras/zmq/configuration_zmq.py create mode 100644 src/lerobot/cameras/zmq/image_server.py diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index bdc7eb33d..e6bffdf1b 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -7,7 +7,7 @@ This guide covers the complete setup process for the Unitree G1 humanoid, from i We support both 29 and 23 DOF G1 EDU version. We introduce: - **`unitree g1` robot class, handling low level read/write from/to the humanoid** -- **ZMQ socket bridge** for remote communication over wlan, allowing for remote policy deployment as well as over eth or directly on the Orin +- **ZMQ socket bridge** for remote communication and camera streaming, allowing for remote policy deployment over wlan, eth or directly on the robot - **Locomotion policies** from NVIDIA gr00t and Amazon FAR Holosoma - **Simulation mode** for testing policies without the physical robot in mujoco @@ -110,7 +110,7 @@ ssh unitree@ # Password: 123 ``` -Replace `` with your robot's actual WiFi IP address (e.g., `172.18.129.215`). +Replace `` with your robot's actual WiFi IP address. 
--- @@ -188,7 +188,7 @@ Press `Ctrl+C` to stop the policy. ## Running in Simulation Mode (MuJoCo) -You can now test and develop policies without a physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. +You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. ## Additional Resources diff --git a/examples/unitree_g1/gr00t_locomotion.py b/examples/unitree_g1/gr00t_locomotion.py index 30e5a27e6..0123b5206 100644 --- a/examples/unitree_g1/gr00t_locomotion.py +++ b/examples/unitree_g1/gr00t_locomotion.py @@ -111,34 +111,29 @@ class GrootLocomotionController: def run_step(self): # Get current observation - robot_state = self.robot.get_observation() + obs = self.robot.get_observation() - if robot_state is None: + if not obs: return # Get command from remote controller - if robot_state.wireless_remote is not None: - self.robot.remote_controller.set(robot_state.wireless_remote) - if self.robot.remote_controller.button[0]: # R1 - raise waist - self.groot_height_cmd += 0.001 - self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - if self.robot.remote_controller.button[4]: # R2 - lower waist - self.groot_height_cmd -= 0.001 - self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - else: - self.robot.remote_controller.lx = 0.0 - self.robot.remote_controller.ly = 0.0 - self.robot.remote_controller.rx = 0.0 - self.robot.remote_controller.ry = 0.0 + if obs["remote.buttons"][0]: # R1 - raise waist + self.groot_height_cmd += 0.001 + self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) + if obs["remote.buttons"][4]: # R2 - lower waist + self.groot_height_cmd -= 0.001 + self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00) - self.cmd[0] = self.robot.remote_controller.ly # Forward/backward - self.cmd[1] = self.robot.remote_controller.lx * -1 # Left/right - self.cmd[2] = self.robot.remote_controller.rx * -1 # Rotation rate + 
self.cmd[0] = obs["remote.ly"] # Forward/backward + self.cmd[1] = obs["remote.lx"] * -1 # Left/right + self.cmd[2] = obs["remote.rx"] * -1 # Rotation rate - # Get joint positions and velocities - for i in range(29): - self.groot_qj_all[i] = robot_state.motor_state[i].q - self.groot_dqj_all[i] = robot_state.motor_state[i].dq + # Get joint positions and velocities from flat dict + for motor in G1_29_JointIndex: + name = motor.name + idx = motor.value + self.groot_qj_all[idx] = obs[f"{name}.q"] + self.groot_dqj_all[idx] = obs[f"{name}.dq"] # Adapt observation for g1_23dof for idx in MISSING_JOINTS: @@ -150,8 +145,8 @@ class GrootLocomotionController: dqj_obs = self.groot_dqj_all.copy() # Express IMU data in gravity frame of reference - quat = robot_state.imu_state.quaternion - ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) + quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]] + ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32) gravity_orientation = self.robot.get_gravity_orientation(quat) # Scale joint positions and velocities before policy inference @@ -219,6 +214,8 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None: config = UnitreeG1Config() robot = UnitreeG1(config) + robot.connect() + # Initialize gr00T locomotion controller groot_controller = GrootLocomotionController( policy_balance=policy_balance, @@ -234,7 +231,7 @@ def run(repo_id: str = DEFAULT_GROOT_REPO_ID) -> None: logger.info("Press Ctrl+C to stop") # Run step - while True: + while not robot._shutdown_event.is_set(): start_time = time.time() groot_controller.run_step() elapsed = time.time() - start_time diff --git a/examples/unitree_g1/holosoma_locomotion.py b/examples/unitree_g1/holosoma_locomotion.py index 017f7835a..3a07023de 100644 --- a/examples/unitree_g1/holosoma_locomotion.py +++ b/examples/unitree_g1/holosoma_locomotion.py @@ -126,24 +126,23 @@ class HolosomaLocomotionController: def 
run_step(self): # Get current observation - robot_state = self.robot.get_observation() + obs = self.robot.get_observation() - if robot_state is None: + if not obs: return # Get command from remote controller - if robot_state.wireless_remote is not None: - self.robot.remote_controller.set(robot_state.wireless_remote) - - ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0 - lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0 - rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0 + ly = obs["remote.ly"] if abs(obs["remote.ly"]) > 0.1 else 0.0 + lx = obs["remote.lx"] if abs(obs["remote.lx"]) > 0.1 else 0.0 + rx = obs["remote.rx"] if abs(obs["remote.rx"]) > 0.1 else 0.0 self.cmd[:] = [ly, -lx, -rx] # Get joint positions and velocities - for i in range(29): - self.qj[i] = robot_state.motor_state[i].q - self.dqj[i] = robot_state.motor_state[i].dq + for motor in G1_29_JointIndex: + name = motor.name + idx = motor.value + self.qj[idx] = obs[f"{name}.q"] + self.dqj[idx] = obs[f"{name}.dq"] # Adapt observation for g1_23dof for idx in MISSING_JOINTS: @@ -151,8 +150,8 @@ class HolosomaLocomotionController: self.dqj[idx] = 0.0 # Express IMU data in gravity frame of reference - quat = robot_state.imu_state.quaternion - ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32) + quat = [obs["imu.quat.w"], obs["imu.quat.x"], obs["imu.quat.y"], obs["imu.quat.z"]] + ang_vel = np.array([obs["imu.gyro.x"], obs["imu.gyro.y"], obs["imu.gyro.z"]], dtype=np.float32) gravity = self.robot.get_gravity_orientation(quat) # Scale joint positions and velocities before policy inference @@ -220,6 +219,7 @@ def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") - # Initialize robot config = UnitreeG1Config() robot = UnitreeG1(config) + robot.connect() holosoma_controller = HolosomaLocomotionController(policy, robot, kp, kd) @@ -230,7 +230,7 @@ 
def run(repo_id: str = DEFAULT_HOLOSOMA_REPO_ID, policy_type: str = "fastsac") - logger.info("Press Ctrl+C to stop") # Run step - while True: + while not robot._shutdown_event.is_set(): start_time = time.time() holosoma_controller.run_step() elapsed = time.time() - start_time diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index 1b2d386d6..c0e7b6284 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -43,6 +43,11 @@ def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[s cameras[key] = Reachy2Camera(cfg) + elif cfg.type == "zmq": + from .zmq.camera_zmq import ZMQCamera + + cameras[key] = ZMQCamera(cfg) + else: try: cameras[key] = cast(Camera, make_device_from_device_class(cfg)) diff --git a/src/lerobot/cameras/zmq/__init__.py b/src/lerobot/cameras/zmq/__init__.py new file mode 100644 index 000000000..d760c5325 --- /dev/null +++ b/src/lerobot/cameras/zmq/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .camera_zmq import ZMQCamera +from .configuration_zmq import ZMQCameraConfig + +__all__ = ["ZMQCamera", "ZMQCameraConfig"] diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py new file mode 100644 index 000000000..1a4155f4b --- /dev/null +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -0,0 +1,235 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ZMQCamera - Captures frames from remote cameras via ZeroMQ using JSON protocol in the +following format: + { + "timestamps": {"camera_name": float}, + "images": {"camera_name": ""} + } +""" + +import base64 +import json +import logging +import time +from threading import Event, Lock, Thread +from typing import Any + +import cv2 +import numpy as np +from numpy.typing import NDArray + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..camera import Camera +from ..configs import ColorMode +from .configuration_zmq import ZMQCameraConfig + +logger = logging.getLogger(__name__) + + +class ZMQCamera(Camera): + """ + Example usage: + ```python + from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig + + config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera") + camera = ZMQCamera(config) + camera.connect() + frame = camera.read() + camera.disconnect() + ``` + """ + + def __init__(self, config: 
ZMQCameraConfig): + super().__init__(config) + import zmq + + self.config = config + self.server_address = config.server_address + self.port = config.port + self.camera_name = config.camera_name + self.color_mode = config.color_mode + self.timeout_ms = config.timeout_ms + + self.context: zmq.Context | None = None + self.socket: zmq.Socket | None = None + self._connected = False + + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_lock: Lock = Lock() + self.latest_frame: NDArray[Any] | None = None + self.new_frame_event: Event = Event() + + def __str__(self) -> str: + return f"ZMQCamera({self.camera_name}@{self.server_address}:{self.port})" + + @property + def is_connected(self) -> bool: + return self._connected and self.context is not None and self.socket is not None + + def connect(self, warmup: bool = True) -> None: + """Connect to ZMQ camera server.""" + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") + + logger.info(f"Connecting to {self}...") + + try: + import zmq + + self.context = zmq.Context() + self.socket = self.context.socket(zmq.SUB) + self.socket.setsockopt_string(zmq.SUBSCRIBE, "") + self.socket.setsockopt(zmq.RCVTIMEO, self.timeout_ms) + self.socket.setsockopt(zmq.CONFLATE, True) + self.socket.connect(f"tcp://{self.server_address}:{self.port}") + self._connected = True + + # Auto-detect resolution + if self.width is None or self.height is None: + h, w = self.read().shape[:2] + self.height = h + self.width = w + logger.info(f"{self} resolution: {w}x{h}") + + logger.info(f"{self} connected.") + + if warmup: + time.sleep(0.1) + + except Exception as e: + self._cleanup() + raise RuntimeError(f"Failed to connect to {self}: {e}") from e + + def _cleanup(self): + """Clean up ZMQ resources.""" + self._connected = False + if self.socket: + self.socket.close() + self.socket = None + if self.context: + self.context.term() + self.context = None + + @staticmethod + def find_cameras() 
-> list[dict[str, Any]]: + """ZMQ cameras require manual configuration (server address/port).""" + return [] + + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + """ + Read a single frame from the ZMQ camera. + + Returns: + np.ndarray: Decoded frame (height, width, 3) + """ + if not self.is_connected or self.socket is None: + raise DeviceNotConnectedError(f"{self} is not connected.") + + try: + message = self.socket.recv_string() + except Exception as e: + if type(e).__name__ == "Again": + raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e + raise + + # Decode JSON message + data = json.loads(message) + + if "images" not in data: + raise RuntimeError(f"{self} invalid message: missing 'images' key") + + images = data["images"] + + # Get image by camera name or first available + if self.camera_name in images: + img_b64 = images[self.camera_name] + elif images: + img_b64 = next(iter(images.values())) + else: + raise RuntimeError(f"{self} no images in message") + + # Decode base64 JPEG + img_bytes = base64.b64decode(img_b64) + frame = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR) + + if frame is None: + raise RuntimeError(f"{self} failed to decode image") + + return frame + + def _read_loop(self) -> None: + while self.stop_event and not self.stop_event.is_set(): + try: + frame = self.read() + with self.frame_lock: + self.latest_frame = frame + self.new_frame_event.set() + except DeviceNotConnectedError: + break + except TimeoutError: + pass + except Exception as e: + logger.warning(f"Read error: {e}") + + def _start_read_thread(self) -> None: + if self.thread and self.thread.is_alive(): + return + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, daemon=True) + self.thread.start() + + def _stop_read_thread(self) -> None: + if self.stop_event: + self.stop_event.set() + if self.thread and self.thread.is_alive(): + self.thread.join(timeout=2.0) + self.thread = None + self.stop_event = 
None + + def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]: + """Read latest frame asynchronously (non-blocking).""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if not self.thread or not self.thread.is_alive(): + self._start_read_thread() + + if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): + raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms") + + with self.frame_lock: + frame = self.latest_frame + self.new_frame_event.clear() + + if frame is None: + raise RuntimeError(f"{self} no frame available") + + return frame + + def disconnect(self) -> None: + """Disconnect from ZMQ camera.""" + if not self.is_connected and not self.thread: + raise DeviceNotConnectedError(f"{self} not connected.") + + self._stop_read_thread() + self._cleanup() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py new file mode 100644 index 000000000..027ae12b5 --- /dev/null +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +from ..configs import CameraConfig, ColorMode + +__all__ = ["ZMQCameraConfig", "ColorMode"] + + +@CameraConfig.register_subclass("zmq") +@dataclass +class ZMQCameraConfig(CameraConfig): + server_address: str + port: int = 5555 + camera_name: str = "zmq_camera" + color_mode: ColorMode = ColorMode.RGB + timeout_ms: int = 5000 + + def __post_init__(self) -> None: + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." + ) + + if self.timeout_ms <= 0: + raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.") + + if not self.server_address: + raise ValueError("`server_address` cannot be empty.") + + if self.port <= 0 or self.port > 65535: + raise ValueError(f"`port` must be between 1 and 65535, but {self.port} is provided.") diff --git a/src/lerobot/cameras/zmq/image_server.py b/src/lerobot/cameras/zmq/image_server.py new file mode 100644 index 000000000..2da366cef --- /dev/null +++ b/src/lerobot/cameras/zmq/image_server.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Streams camera images over ZMQ. +Uses lerobot's OpenCVCamera for capture, encodes images to base64 and sends them over ZMQ. 
+""" + +import base64 +import contextlib +import json +import logging +import time +from collections import deque + +import cv2 +import numpy as np +import zmq + +from lerobot.cameras.configs import ColorMode +from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig + +logger = logging.getLogger(__name__) + + +def encode_image(image: np.ndarray, quality: int = 80) -> str: + """Encode RGB image to base64 JPEG string.""" + _, buffer = cv2.imencode(".jpg", image, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) + return base64.b64encode(buffer).decode("utf-8") + + +class ImageServer: + def __init__(self, config: dict, port: int = 5555): + self.fps = config.get("fps", 30) + self.cameras: dict[str, OpenCVCamera] = {} + + for name, cfg in config.get("cameras", {}).items(): + shape = cfg.get("shape", [480, 640]) + cam_config = OpenCVCameraConfig( + index_or_path=cfg.get("device_id", 0), + fps=self.fps, + width=shape[1], + height=shape[0], + color_mode=ColorMode.RGB, + ) + camera = OpenCVCamera(cam_config) + camera.connect() + self.cameras[name] = camera + logger.info(f"Camera {name}: {shape[1]}x{shape[0]}") + + # ZMQ PUB socket + self.context = zmq.Context() + self.socket = self.context.socket(zmq.PUB) + self.socket.setsockopt(zmq.SNDHWM, 20) + self.socket.setsockopt(zmq.LINGER, 0) + self.socket.bind(f"tcp://*:{port}") + + logger.info(f"ImageServer running on port {port}") + + def run(self): + frame_count = 0 + frame_times = deque(maxlen=60) + + try: + while True: + t0 = time.time() + + # Build message + message = {"timestamps": {}, "images": {}} + for name, cam in self.cameras.items(): + frame = cam.read() # Returns RGB + message["timestamps"][name] = time.time() + message["images"][name] = encode_image(frame) + + # Send as JSON string (suppress if buffer full) + with contextlib.suppress(zmq.Again): + self.socket.send_string(json.dumps(message), zmq.NOBLOCK) + + frame_count += 1 + frame_times.append(time.time() - t0) + + if frame_count % 60 == 0: + 
logger.debug(f"FPS: {len(frame_times) / sum(frame_times):.1f}") + + sleep = (1.0 / self.fps) - (time.time() - t0) + if sleep > 0: + time.sleep(sleep) + + except KeyboardInterrupt: + pass + finally: + for cam in self.cameras.values(): + cam.disconnect() + self.socket.close() + self.context.term() + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + config = {"fps": 30, "cameras": {"head_camera": {"device_id": 4, "shape": [480, 640]}}} + ImageServer(config, port=5555).run() diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index c66edbd1c..0b163019d 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -16,6 +16,8 @@ from dataclasses import dataclass, field +from lerobot.cameras import CameraConfig + from ..config import RobotConfig _GAINS: dict[str, dict[str, list[float]]] = { @@ -60,3 +62,6 @@ class UnitreeG1Config(RobotConfig): # Socket config for ZMQ bridge robot_ip: str = "192.168.123.164" # default G1 IP + + # Cameras (ZMQ-based remote cameras) + cameras: dict[str, CameraConfig] = field(default_factory=dict) diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 5bc4f3110..1764f31b5 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -23,13 +23,8 @@ from functools import cached_property from typing import Any import numpy as np -from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_ -from unitree_sdk2py.idl.unitree_hg.msg.dds_ import ( - LowCmd_ as hg_LowCmd, - LowState_ as hg_LowState, -) -from unitree_sdk2py.utils.crc import CRC +from lerobot.cameras.utils import make_cameras_from_configs from lerobot.envs.factory import make_env from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex @@ -43,8 +38,6 @@ logger = logging.getLogger(__name__) kTopicLowCommand_Debug = 
"rt/lowcmd" kTopicLowState = "rt/lowstate" -G1_29_Num_Motors = 29 - @dataclass class MotorState: @@ -66,28 +59,12 @@ class IMUState: # g1 observation class @dataclass class G1_29_LowState: # noqa: N801 - motor_state: list[MotorState] = field( - default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)] - ) + motor_state: list[MotorState] = field(default_factory=lambda: [MotorState() for _ in G1_29_JointIndex]) imu_state: IMUState = field(default_factory=IMUState) wireless_remote: Any = None # Raw wireless remote data mode_machine: int = 0 # Robot mode -class DataBuffer: - def __init__(self): - self.data = None - self.lock = threading.Lock() - - def get_data(self): - with self.lock: - return self.data - - def set_data(self, data): - with self.lock: - self.data = data - - class UnitreeG1(Robot): config_class = UnitreeG1Config name = "unitree_g1" @@ -117,9 +94,12 @@ class UnitreeG1(Robot): logger.info("Initialize UnitreeG1...") self.config = config - self.control_dt = config.control_dt + # Initialize cameras config (ZMQ-based) - actual connection in connect() + self._cameras = make_cameras_from_configs(config.cameras) + + # Import channel classes based on mode if config.is_simulation: from unitree_sdk2py.core.channel import ( ChannelFactoryInitialize, @@ -133,62 +113,33 @@ class UnitreeG1(Robot): ChannelSubscriber, ) - # connect robot - self.ChannelFactoryInitialize = ChannelFactoryInitialize - self.connect() + # Store for use in connect() + self._ChannelFactoryInitialize = ChannelFactoryInitialize + self._ChannelPublisher = ChannelPublisher + self._ChannelSubscriber = ChannelSubscriber - # initialize direct motor control interface - self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd) - self.lowcmd_publisher.Init() - self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState) - self.lowstate_subscriber.Init() - self.lowstate_buffer = DataBuffer() - - # initialize subscribe thread to read robot state + # Initialize 
state variables + self.sim_env = None + self._env_wrapper = None + self._lowstate = None self._shutdown_event = threading.Event() - self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state) - self.subscribe_thread.start() - - while not self.is_connected: - time.sleep(0.1) - - # initialize hg's lowcmd msg - self.crc = CRC() - self.msg = unitree_hg_msg_dds__LowCmd_() - self.msg.mode_pr = 0 - - # Wait for first state message to arrive - lowstate = None - while lowstate is None: - lowstate = self.lowstate_buffer.get_data() - if lowstate is None: - time.sleep(0.01) - logger.warning("[UnitreeG1] Waiting for robot state...") - logger.warning("[UnitreeG1] Connected to robot.") - self.msg.mode_machine = lowstate.mode_machine - - # initialize all motors with unified kp/kd from config - self.kp = np.array(config.kp, dtype=np.float32) - self.kd = np.array(config.kd, dtype=np.float32) - - for id in G1_29_JointIndex: - self.msg.motor_cmd[id].mode = 1 - self.msg.motor_cmd[id].kp = self.kp[id.value] - self.msg.motor_cmd[id].kd = self.kd[id.value] - self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q - - # Initialize remote controller + self.subscribe_thread = None self.remote_controller = self.RemoteController() def _subscribe_motor_state(self): # polls robot state @ 250Hz while not self._shutdown_event.is_set(): start_time = time.time() + + # Step simulation if in simulation mode + if self.config.is_simulation and self.sim_env is not None: + self.sim_env.step() + msg = self.lowstate_subscriber.Read() if msg is not None: lowstate = G1_29_LowState() - # Capture motor states - for id in range(G1_29_Num_Motors): + # Capture motor states using jointindex + for id in G1_29_JointIndex: lowstate.motor_state[id].q = msg.motor_state[id].q lowstate.motor_state[id].dq = msg.motor_state[id].dq lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est @@ -207,7 +158,7 @@ class UnitreeG1(Robot): # Capture mode_machine lowstate.mode_machine = msg.mode_machine - 
self.lowstate_buffer.set_data(lowstate) + self._lowstate = lowstate current_time = time.time() all_t_elapsed = current_time - start_time @@ -216,7 +167,7 @@ class UnitreeG1(Robot): @cached_property def action_features(self) -> dict[str, type]: - return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex} + return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex} def calibrate(self) -> None: # robot is already calibrated pass @@ -225,20 +176,153 @@ class UnitreeG1(Robot): pass def connect(self, calibrate: bool = True) -> None: # connect to DDS + from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_ + from unitree_sdk2py.idl.unitree_hg.msg.dds_ import ( + LowCmd_ as hg_LowCmd, + LowState_ as hg_LowState, + ) + from unitree_sdk2py.utils.crc import CRC + + # Initialize DDS channel and simulation environment if self.config.is_simulation: - self.ChannelFactoryInitialize(0, "lo") - self.mujoco_env = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) + self._ChannelFactoryInitialize(0, "lo") + self._env_wrapper = make_env("lerobot/unitree-g1-mujoco", trust_remote_code=True) + # Extract the actual gym env from the dict structure + self.sim_env = self._env_wrapper["hub_env"][0].envs[0] else: - self.ChannelFactoryInitialize(0) + self._ChannelFactoryInitialize(0) + + # Initialize direct motor control interface + self.lowcmd_publisher = self._ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd) + self.lowcmd_publisher.Init() + self.lowstate_subscriber = self._ChannelSubscriber(kTopicLowState, hg_LowState) + self.lowstate_subscriber.Init() + + # Start subscribe thread to read robot state + self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state) + self.subscribe_thread.start() + + # Connect cameras + for cam in self._cameras.values(): + if not cam.is_connected: + cam.connect() + + logger.info(f"Connected {len(self._cameras)} camera(s).") + + # Initialize lowcmd message + self.crc = CRC() 
+ self.msg = unitree_hg_msg_dds__LowCmd_() + self.msg.mode_pr = 0 + + # Wait for first state message to arrive + lowstate = None + while lowstate is None: + lowstate = self._lowstate + if lowstate is None: + time.sleep(0.01) + logger.warning("[UnitreeG1] Waiting for robot state...") + logger.warning("[UnitreeG1] Connected to robot.") + self.msg.mode_machine = lowstate.mode_machine + + # Initialize all motors with unified kp/kd from config + self.kp = np.array(self.config.kp, dtype=np.float32) + self.kd = np.array(self.config.kd, dtype=np.float32) + + for id in G1_29_JointIndex: + self.msg.motor_cmd[id].mode = 1 + self.msg.motor_cmd[id].kp = self.kp[id.value] + self.msg.motor_cmd[id].kd = self.kd[id.value] + self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q def disconnect(self): + # Signal thread to stop and unblock any waits self._shutdown_event.set() - self.subscribe_thread.join(timeout=2.0) - if self.config.is_simulation: - self.mujoco_env["hub_env"][0].envs[0].kill_sim() + + # Wait for subscribe thread to finish + if self.subscribe_thread is not None: + self.subscribe_thread.join(timeout=2.0) + if self.subscribe_thread.is_alive(): + logger.warning("Subscribe thread did not stop cleanly") + + # Close simulation environment + if self.config.is_simulation and self.sim_env is not None: + try: + # Force-kill the image publish subprocess first to avoid long waits + if hasattr(self.sim_env, "simulator") and hasattr(self.sim_env.simulator, "sim_env"): + sim_env_inner = self.sim_env.simulator.sim_env + if hasattr(sim_env_inner, "image_publish_process"): + proc = sim_env_inner.image_publish_process + if proc.process and proc.process.is_alive(): + logger.info("Force-terminating image publish subprocess...") + proc.stop_event.set() + proc.process.terminate() + proc.process.join(timeout=1) + if proc.process.is_alive(): + proc.process.kill() + self.sim_env.close() + except Exception as e: + logger.warning(f"Error closing sim_env: {e}") + self.sim_env = None + 
self._env_wrapper = None + + # Disconnect cameras + for cam in self._cameras.values(): + cam.disconnect() def get_observation(self) -> dict[str, Any]: - return self.lowstate_buffer.get_data() + lowstate = self._lowstate + if lowstate is None: + return {} + + obs = {} + + # Motors - q, dq, tau for all joints + for motor in G1_29_JointIndex: + name = motor.name + idx = motor.value + obs[f"{name}.q"] = lowstate.motor_state[idx].q + obs[f"{name}.dq"] = lowstate.motor_state[idx].dq + obs[f"{name}.tau"] = lowstate.motor_state[idx].tau_est + + # IMU - gyroscope + if lowstate.imu_state.gyroscope: + obs["imu.gyro.x"] = lowstate.imu_state.gyroscope[0] + obs["imu.gyro.y"] = lowstate.imu_state.gyroscope[1] + obs["imu.gyro.z"] = lowstate.imu_state.gyroscope[2] + + # IMU - accelerometer + if lowstate.imu_state.accelerometer: + obs["imu.accel.x"] = lowstate.imu_state.accelerometer[0] + obs["imu.accel.y"] = lowstate.imu_state.accelerometer[1] + obs["imu.accel.z"] = lowstate.imu_state.accelerometer[2] + + # IMU - quaternion + if lowstate.imu_state.quaternion: + obs["imu.quat.w"] = lowstate.imu_state.quaternion[0] + obs["imu.quat.x"] = lowstate.imu_state.quaternion[1] + obs["imu.quat.y"] = lowstate.imu_state.quaternion[2] + obs["imu.quat.z"] = lowstate.imu_state.quaternion[3] + + # IMU - rpy + if lowstate.imu_state.rpy: + obs["imu.rpy.roll"] = lowstate.imu_state.rpy[0] + obs["imu.rpy.pitch"] = lowstate.imu_state.rpy[1] + obs["imu.rpy.yaw"] = lowstate.imu_state.rpy[2] + + # Controller - parse wireless_remote and add to obs + if lowstate.wireless_remote and len(lowstate.wireless_remote) >= 24: + self.remote_controller.set(lowstate.wireless_remote) + obs["remote.buttons"] = self.remote_controller.button.copy() + obs["remote.lx"] = self.remote_controller.lx + obs["remote.ly"] = self.remote_controller.ly + obs["remote.rx"] = self.remote_controller.rx + obs["remote.ry"] = self.remote_controller.ry + + # Cameras - read images from ZMQ cameras + for cam_name, cam in self._cameras.items(): + 
obs[cam_name] = cam.async_read() + + return obs @property def is_calibrated(self) -> bool: @@ -246,11 +330,15 @@ class UnitreeG1(Robot): @property def is_connected(self) -> bool: - return self.lowstate_buffer.get_data() is not None + return self._lowstate is not None @property def _motors_ft(self) -> dict[str, type]: - return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex} + return {f"{G1_29_JointIndex(motor).name}.q": float for motor in G1_29_JointIndex} + + @property + def cameras(self) -> dict: + return self._cameras @property def _cameras_ft(self) -> dict[str, tuple]: @@ -293,39 +381,51 @@ class UnitreeG1(Robot): self, control_dt: float | None = None, default_positions: list[float] | None = None, - ) -> None: # interpolate to default position + ) -> None: # move robot to default position if control_dt is None: control_dt = self.config.control_dt if default_positions is None: default_positions = np.array(self.config.default_positions, dtype=np.float32) - total_time = 3.0 - num_steps = int(total_time / control_dt) + if self.config.is_simulation and self.sim_env is not None: + self.sim_env.reset() - # get current state - robot_state = self.get_observation() - - # record current positions - init_dof_pos = np.zeros(29, dtype=np.float32) - for i in range(29): - init_dof_pos[i] = robot_state.motor_state[i].q - - # Interpolate to default position - for step in range(num_steps): - start_time = time.time() - - alpha = step / num_steps - action_dict = {} for motor in G1_29_JointIndex: - target_pos = default_positions[motor.value] - interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha - action_dict[f"{motor.name}.q"] = float(interp_pos) + self.msg.motor_cmd[motor.value].q = default_positions[motor.value] + self.msg.motor_cmd[motor.value].qd = 0 + self.msg.motor_cmd[motor.value].kp = self.kp[motor.value] + self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] + self.msg.motor_cmd[motor.value].tau = 0 + self.msg.crc = 
self.crc.Crc(self.msg) + self.lowcmd_publisher.Write(self.msg) + else: + total_time = 3.0 + num_steps = int(total_time / control_dt) - self.send_action(action_dict) + # get current state + obs = self.get_observation() - # Maintain constant control rate - elapsed = time.time() - start_time - sleep_time = max(0, control_dt - elapsed) - time.sleep(sleep_time) + # record current positions + init_dof_pos = np.zeros(29, dtype=np.float32) + for motor in G1_29_JointIndex: + init_dof_pos[motor.value] = obs[f"{motor.name}.q"] + + # Interpolate to default position + for step in range(num_steps): + start_time = time.time() + + alpha = step / num_steps + action_dict = {} + for motor in G1_29_JointIndex: + target_pos = default_positions[motor.value] + interp_pos = init_dof_pos[motor.value] * (1 - alpha) + target_pos * alpha + action_dict[f"{motor.name}.q"] = float(interp_pos) + + self.send_action(action_dict) + + # Maintain constant control rate + elapsed = time.time() - start_time + sleep_time = max(0, control_dt - elapsed) + time.sleep(sleep_time) logger.info("Reached default position") diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 8eafa8e6d..9a327d986 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -74,6 +74,7 @@ from lerobot.cameras import ( # noqa: F401 ) from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser from lerobot.configs.policies import PreTrainedConfig from lerobot.datasets.image_writer import safe_stop_image_writer @@ -103,6 +104,7 @@ from lerobot.robots import ( # noqa: F401 make_robot_from_config, omx_follower, so_follower, + unitree_g1, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -508,6 +510,11 @@ 
def record(cfg: RecordConfig) -> LeRobotDataset: (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] ): log_say("Reset the environment", cfg.play_sounds) + + # reset g1 robot + if robot.name == "unitree_g1": + robot.reset() + record_loop( robot=robot, events=events, diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 8e0d9cf6d..c16271932 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -60,6 +60,7 @@ from lerobot.robots import ( # noqa: F401 make_robot_from_config, omx_follower, so_follower, + unitree_g1, ) from lerobot.utils.constants import ACTION from lerobot.utils.import_utils import register_third_party_plugins From b8ec1152d447190e01fbe065f25076c730f20f10 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 12 Jan 2026 18:05:16 +0100 Subject: [PATCH 033/111] fix(robots): add reachy2 fixes (#2783) * fix(robots): add reachy2 fixes * tests(robots): remove reachy sdk stub --- docs/source/reachy2.mdx | 53 +++-- pyproject.toml | 2 +- .../configuration_reachy2_camera.py | 8 +- .../cameras/reachy2_camera/reachy2_camera.py | 183 +++++------------- .../robots/reachy2/configuration_reachy2.py | 34 ++-- src/lerobot/robots/reachy2/robot_reachy2.py | 10 +- src/lerobot/scripts/lerobot_record.py | 3 + src/lerobot/scripts/lerobot_replay.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../reachy2_teleoperator.py | 51 +++-- src/lerobot/utils/import_utils.py | 1 + tests/cameras/test_reachy2_camera.py | 14 +- tests/conftest.py | 1 - tests/plugins/reachy2_sdk.py | 46 ----- tests/robots/test_reachy2.py | 2 + 15 files changed, 165 insertions(+), 246 deletions(-) delete mode 100644 tests/plugins/reachy2_sdk.py diff --git a/docs/source/reachy2.mdx b/docs/source/reachy2.mdx index 7d3dc1b60..51b09acd2 100644 --- a/docs/source/reachy2.mdx +++ b/docs/source/reachy2.mdx @@ -38,6 +38,7 @@ docker run --rm -it \ start_rviz:=true start_sdk_server:=true 
mujoco:=true ``` +> [!NOTE] > If MuJoCo runs slowly (low simulation frequency), append `-e LD_LIBRARY_PATH="/opt/host-libs:$LD_LIBRARY_PATH" \` to the previous command to improve performance: > > ``` @@ -141,7 +142,7 @@ If you choose this option but still want to use the VR teleoperation application First add reachy2 and reachy2_teleoperator to the imports of the record script. Then you can use the following command: ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=reachy2 \ --robot.ip_address=192.168.0.200 \ --robot.id=r2-0000 \ @@ -150,6 +151,7 @@ python -m lerobot.record \ --teleop.type=reachy2_teleoperator \ --teleop.ip_address=192.168.0.200 \ --teleop.with_mobile_base=false \ + --robot.with_torso_camera=true \ --dataset.repo_id=pollen_robotics/record_test \ --dataset.single_task="Reachy 2 recording test" \ --dataset.num_episodes=1 \ @@ -165,7 +167,7 @@ python -m lerobot.record \ **Extended setup overview (all options included):** ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=reachy2 \ --robot.ip_address=192.168.0.200 \ --robot.use_external_commands=true \ @@ -177,6 +179,8 @@ python -m lerobot.record \ --robot.with_left_teleop_camera=true \ --robot.with_right_teleop_camera=true \ --robot.with_torso_camera=false \ + --robot.camera_width=640 \ + --robot.camera_height=480 \ --robot.disable_torque_on_disconnect=false \ --robot.max_relative_target=5.0 \ --teleop.type=reachy2_teleoperator \ @@ -212,9 +216,10 @@ Must be set to true if a compliant Reachy 2 is used to control another one. From our initial tests, recording **all** joints when only some are moving can reduce model quality with certain policies. To avoid this, you can exclude specific parts from recording and replay using: -```` +```bash --robot.with_=false -```, +``` + with `` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`. It determine whether the corresponding part is recorded in the observations. True if not set. 
@@ -222,49 +227,60 @@ By default, **all parts are recorded**. The same per-part mechanism is available in `reachy2_teleoperator` as well. -```` - +```bash --teleop.with\_ - ``` + with `` being one of : `mobile_base`, `l_arm`, `r_arm", `neck`, `antennas`. Determine whether the corresponding part is recorded in the actions. True if not set. > **Important:** In a given session, the **enabled parts must match** on both the robot and the teleoperator. -For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`. +> For example, if the robot runs with `--robot.with_mobile_base=false`, the teleoperator must disable the same part `--teleoperator.with_mobile_base=false`. ##### Use the relevant cameras -You can do the same for **cameras**. By default, only the **teleoperation cameras** are recorded (both `left_teleop_camera` and `right_teleop_camera`). Enable or disable each camera with: +You can do the same for **cameras**. Enable or disable each camera with default parameters using: +```bash +--robot.with_left_teleop_camera= \ +--robot.with_right_teleop_camera= \ +--robot.with_torso_camera= ``` ---robot.with_left_teleop_camera= ---robot.with_right_teleop_camera= ---robot.with_torso_camera= +By default, no camera is recorded, all camera arguments are set to `false`. +If you want to, you can use custom `width` and `height` parameters for Reachy 2's cameras using the `--robot.camera_width` & `--robot.camera_height` argument: -```` +```bash +--robot.camera_width=1920 \ +--robot.camera_height=1080 +``` +This will change the resolution of all 3 default robot cameras (enabled by the above bool arguments). 
+ +If you want, you can add additional cameras other than the ones in the robot as usual with: + +```bash +--robot.cameras="{ extra: {type: opencv, index_or_path: 42, width: 640, height: 480, fps: 30}}" \ +``` ## Step 2: Replay Make sure the robot is configured with the same parts as the dataset: ```bash -python -m lerobot.replay \ +lerobot-replay \ --robot.type=reachy2 \ --robot.ip_address=192.168.0.200 \ --robot.use_external_commands=false \ --robot.with_mobile_base=false \ --dataset.repo_id=pollen_robotics/record_test \ --dataset.episode=0 - --display_data=true -```` +``` ## Step 3: Train ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=pollen_robotics/record_test \ --policy.type=act \ --output_dir=outputs/train/reachy2_test \ @@ -277,10 +293,9 @@ python -m lerobot.scripts.train \ ## Step 4: Evaluate ```bash -python -m lerobot.record \ +lerobot-eval \ --robot.type=reachy2 \ --robot.ip_address=192.168.0.200 \ - --display_data=false \ --dataset.repo_id=pollen_robotics/eval_record_test \ --dataset.single_task="Evaluate reachy2 policy" \ --dataset.num_episodes=10 \ diff --git a/pyproject.toml b/pyproject.toml index e8f334c77..067cb1df8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ unitree_g1 = [ "pyzmq>=26.2.1,<28.0.0", "onnxruntime>=1.16.0,<2.0.0" ] -reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"] +reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] kinematics = ["lerobot[placo-dep]"] intelrealsense = [ "pyrealsense2>=2.55.1.6486,<2.57.0 ; sys_platform != 'darwin'", diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py index f26cf2ad1..ca6db4f03 100644 --- a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py @@ -35,18 +35,19 @@ class Reachy2CameraConfig(CameraConfig): name="teleop", image_type="left", ip_address="192.168.0.200", # IP address of the 
robot - fps=15, + port=50065, # Port of the camera server width=640, height=480, + fps=30, # Not configurable for Reachy 2 cameras color_mode=ColorMode.RGB, - ) # Left teleop camera, 640x480 @ 15FPS + ) # Left teleop camera, 640x480 @ 30FPS ``` Attributes: name: Name of the camera device. Can be "teleop" or "depth". image_type: Type of image stream. For "teleop" camera, can be "left" or "right". For "depth" camera, can be "rgb" or "depth". (depth is not supported yet) - fps: Requested frames per second for the color stream. + fps: Requested frames per second for the color stream. Not configurable for Reachy 2 cameras. width: Requested frame width in pixels for the color stream. height: Requested frame height in pixels for the color stream. color_mode: Color mode for image output (RGB or BGR). Defaults to RGB. @@ -62,7 +63,6 @@ class Reachy2CameraConfig(CameraConfig): color_mode: ColorMode = ColorMode.RGB ip_address: str | None = "localhost" port: int = 50065 - # use_depth: bool = False def __post_init__(self) -> None: if self.name not in ["teleop", "depth"]: diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 30e096767..c8916c5ee 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -16,12 +16,13 @@ Provides the Reachy2Camera class for capturing frames from Reachy 2 cameras using Reachy 2's CameraManager. 
""" +from __future__ import annotations + import logging import os import platform import time -from threading import Event, Lock, Thread -from typing import Any +from typing import TYPE_CHECKING, Any from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing @@ -30,10 +31,19 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" import cv2 # type: ignore # TODO: add type stubs for OpenCV import numpy as np # type: ignore # TODO: add type stubs for numpy -from reachy2_sdk.media.camera import CameraView # type: ignore # TODO: add type stubs for reachy2_sdk -from reachy2_sdk.media.camera_manager import ( # type: ignore # TODO: add type stubs for reachy2_sdk - CameraManager, -) + +from lerobot.utils.import_utils import _reachy2_sdk_available + +if TYPE_CHECKING or _reachy2_sdk_available: + from reachy2_sdk.media.camera import CameraView + from reachy2_sdk.media.camera_manager import CameraManager +else: + CameraManager = None + + class CameraView: + LEFT = 0 + RIGHT = 1 + from lerobot.utils.errors import DeviceNotConnectedError @@ -69,17 +79,10 @@ class Reachy2Camera(Camera): self.config = config - self.fps = config.fps self.color_mode = config.color_mode self.cam_manager: CameraManager | None = None - self.thread: Thread | None = None - self.stop_event: Event | None = None - self.frame_lock: Lock = Lock() - self.latest_frame: NDArray[Any] | None = None - self.new_frame_event: Event = Event() - def __str__(self) -> str: return f"{self.__class__.__name__}({self.config.name}, {self.config.image_type})" @@ -100,44 +103,23 @@ class Reachy2Camera(Camera): def connect(self, warmup: bool = True) -> None: """ Connects to the Reachy2 CameraManager as specified in the configuration. + + Raises: + DeviceNotConnectedError: If the camera is not connected. 
""" self.cam_manager = CameraManager(host=self.config.ip_address, port=self.config.port) + if self.cam_manager is None: + raise DeviceNotConnectedError(f"Could not connect to {self}.") self.cam_manager.initialize_cameras() logger.info(f"{self} connected.") @staticmethod - def find_cameras(ip_address: str = "localhost", port: int = 50065) -> list[dict[str, Any]]: + def find_cameras() -> list[dict[str, Any]]: """ - Detects available Reachy 2 cameras. - - Returns: - List[Dict[str, Any]]: A list of dictionaries, - where each dictionary contains 'name', 'stereo', - and the default profile properties (width, height, fps). + Detection not implemented for Reachy2 cameras. """ - initialized_cameras = [] - camera_manager = CameraManager(host=ip_address, port=port) - - for camera in [camera_manager.teleop, camera_manager.depth]: - if camera is None: - continue - - height, width, _, _, _, _, _ = camera.get_parameters() - - camera_info = { - "name": camera._cam_info.name, - "stereo": camera._cam_info.stereo, - "default_profile": { - "width": width, - "height": height, - "fps": 30, - }, - } - initialized_cameras.append(camera_info) - - camera_manager.disconnect() - return initialized_cameras + raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.") def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ @@ -155,95 +137,49 @@ class Reachy2Camera(Camera): (height, width, channels), using the specified or default color mode and applying any configured rotation. 
""" + start_time = time.perf_counter() + if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - start_time = time.perf_counter() + if self.cam_manager is None: + raise DeviceNotConnectedError(f"{self} is not connected.") frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8) - if self.cam_manager is None: - raise DeviceNotConnectedError(f"{self} is not connected.") + if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): + if self.config.image_type == "left": + frame = self.cam_manager.teleop.get_frame( + CameraView.LEFT, size=(self.config.width, self.config.height) + )[0] + elif self.config.image_type == "right": + frame = self.cam_manager.teleop.get_frame( + CameraView.RIGHT, size=(self.config.width, self.config.height) + )[0] + elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"): + if self.config.image_type == "depth": + frame = self.cam_manager.depth.get_depth_frame()[0] + elif self.config.image_type == "rgb": + frame = self.cam_manager.depth.get_frame(size=(self.config.width, self.config.height))[0] else: - if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): - if self.config.image_type == "left": - frame = self.cam_manager.teleop.get_frame(CameraView.LEFT, size=(640, 480))[0] - elif self.config.image_type == "right": - frame = self.cam_manager.teleop.get_frame(CameraView.RIGHT, size=(640, 480))[0] - elif self.config.name == "depth" and hasattr(self.cam_manager, "depth"): - if self.config.image_type == "depth": - frame = self.cam_manager.depth.get_depth_frame()[0] - elif self.config.image_type == "rgb": - frame = self.cam_manager.depth.get_frame(size=(640, 480))[0] + raise ValueError(f"Invalid camera name '{self.config.name}'. 
Expected 'teleop' or 'depth'.") - if frame is None: - return np.empty((0, 0, 3), dtype=np.uint8) + if frame is None: + return np.empty((0, 0, 3), dtype=np.uint8) - if self.config.color_mode == "rgb": - frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + if self.config.color_mode == "rgb": + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") return frame - def _read_loop(self) -> None: - """ - Internal loop run by the background thread for asynchronous reading. - - On each iteration: - 1. Reads a color frame - 2. Stores result in latest_frame (thread-safe) - 3. Sets new_frame_event to notify listeners - - Stops on DeviceNotConnectedError, logs other errors and continues. - """ - if self.stop_event is None: - raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") - - while not self.stop_event.is_set(): - try: - color_image = self.read() - - with self.frame_lock: - self.latest_frame = color_image - self.new_frame_event.set() - - except DeviceNotConnectedError: - break - except Exception as e: - logger.warning(f"Error reading frame in background thread for {self}: {e}") - - def _start_read_thread(self) -> None: - """Starts or restarts the background read thread if it's not running.""" - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=0.1) - if self.stop_event is not None: - self.stop_event.set() - - self.stop_event = Event() - self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") - self.thread.daemon = True - self.thread.start() - - def _stop_read_thread(self) -> None: - """Signals the background read thread to stop and waits for it to join.""" - if self.stop_event is not None: - self.stop_event.set() - - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=2.0) - - self.thread = None - self.stop_event = None - def async_read(self, 
timeout_ms: float = 200) -> NDArray[Any]: """ - Reads the latest available frame asynchronously. + Reads the latest available frame. - This method retrieves the most recent frame captured by the background - read thread. It does not block waiting for the camera hardware directly, - but may wait up to timeout_ms for the background thread to provide a frame. + This method retrieves the most recent frame available in Reachy 2's low-level software. Args: timeout_ms (float): Maximum time in milliseconds to wait for a frame @@ -261,22 +197,10 @@ class Reachy2Camera(Camera): if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): - self._start_read_thread() - - if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): - thread_alive = self.thread is not None and self.thread.is_alive() - raise TimeoutError( - f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " - f"Read thread alive: {thread_alive}." - ) - - with self.frame_lock: - frame = self.latest_frame - self.new_frame_event.clear() + frame = self.read() if frame is None: - raise RuntimeError(f"Internal error: Event set but no frame available for {self}.") + raise RuntimeError(f"Internal error: No frame available for {self}.") return frame @@ -287,12 +211,9 @@ class Reachy2Camera(Camera): Raises: DeviceNotConnectedError: If the camera is already disconnected. 
""" - if not self.is_connected and self.thread is None: + if not self.is_connected: raise DeviceNotConnectedError(f"{self} not connected.") - if self.thread is not None: - self._stop_read_thread() - if self.cam_manager is not None: self.cam_manager.disconnect() diff --git a/src/lerobot/robots/reachy2/configuration_reachy2.py b/src/lerobot/robots/reachy2/configuration_reachy2.py index aa25351c6..63293e675 100644 --- a/src/lerobot/robots/reachy2/configuration_reachy2.py +++ b/src/lerobot/robots/reachy2/configuration_reachy2.py @@ -30,6 +30,8 @@ class Reachy2RobotConfig(RobotConfig): # IP address of the Reachy 2 robot ip_address: str | None = "localhost" + # Port of the Reachy 2 robot + port: int = 50065 # If True, turn_off_smoothly() will be sent to the robot before disconnecting. disable_torque_on_disconnect: bool = False @@ -51,11 +53,16 @@ class Reachy2RobotConfig(RobotConfig): # Robot cameras # Set to True if you want to use the corresponding cameras in the observations. - # By default, only the teleop cameras are used. - with_left_teleop_camera: bool = True - with_right_teleop_camera: bool = True + # By default, no camera is used. + with_left_teleop_camera: bool = False + with_right_teleop_camera: bool = False with_torso_camera: bool = False + # Camera parameters + camera_width: int = 640 + camera_height: int = 480 + + # For cameras other than the 3 default Reachy 2 cameras. 
cameras: dict[str, CameraConfig] = field(default_factory=dict) def __post_init__(self) -> None: @@ -65,9 +72,10 @@ class Reachy2RobotConfig(RobotConfig): name="teleop", image_type="left", ip_address=self.ip_address, - fps=15, - width=640, - height=480, + port=self.port, + width=self.camera_width, + height=self.camera_height, + fps=30, # Not configurable for Reachy 2 cameras color_mode=ColorMode.RGB, ) if self.with_right_teleop_camera: @@ -75,9 +83,10 @@ class Reachy2RobotConfig(RobotConfig): name="teleop", image_type="right", ip_address=self.ip_address, - fps=15, - width=640, - height=480, + port=self.port, + width=self.camera_width, + height=self.camera_height, + fps=30, # Not configurable for Reachy 2 cameras color_mode=ColorMode.RGB, ) if self.with_torso_camera: @@ -85,9 +94,10 @@ class Reachy2RobotConfig(RobotConfig): name="depth", image_type="rgb", ip_address=self.ip_address, - fps=15, - width=640, - height=480, + port=self.port, + width=self.camera_width, + height=self.camera_height, + fps=30, # Not configurable for Reachy 2 cameras color_mode=ColorMode.RGB, ) diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index ecc488a79..74742ee8d 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -13,19 +13,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import time -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np -from reachy2_sdk import ReachySDK from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.utils.import_utils import _reachy2_sdk_available from ..robot import Robot from ..utils import ensure_safe_goal_position from .configuration_reachy2 import Reachy2RobotConfig +if TYPE_CHECKING or _reachy2_sdk_available: + from reachy2_sdk import ReachySDK +else: + ReachySDK = None + # {lerobot_keys: reachy2_sdk_keys} REACHY2_NECK_JOINTS = { "neck_yaw.pos": "head.neck.yaw", diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 9a327d986..f03776989 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -73,6 +73,7 @@ from lerobot.cameras import ( # noqa: F401 CameraConfig, # noqa: F401 ) from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.cameras.reachy2_camera.configuration_reachy2_camera import Reachy2CameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser @@ -103,6 +104,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + reachy2, so_follower, unitree_g1, ) @@ -114,6 +116,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + reachy2_teleoperator, so_leader, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index c16271932..49c06d643 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -59,6 +59,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, 
make_robot_from_config, omx_follower, + reachy2, so_follower, unitree_g1, ) diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index f2392bb51..05d4534d4 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -76,6 +76,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + reachy2, so_follower, ) from lerobot.teleoperators import ( # noqa: F401 @@ -88,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + reachy2_teleoperator, so_leader, ) from lerobot.utils.import_utils import register_third_party_plugins diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py index 5a427dd71..578aaa7b2 100644 --- a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py +++ b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py @@ -13,11 +13,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations import logging import time +from typing import TYPE_CHECKING -from reachy2_sdk import ReachySDK +from lerobot.utils.import_utils import _reachy2_sdk_available + +if TYPE_CHECKING or _reachy2_sdk_available: + from reachy2_sdk import ReachySDK +else: + ReachySDK = None + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig @@ -75,6 +84,7 @@ class Reachy2Teleoperator(Teleoperator): def __init__(self, config: Reachy2TeleoperatorConfig): super().__init__(config) + self.config = config self.reachy: None | ReachySDK = None @@ -117,9 +127,13 @@ class Reachy2Teleoperator(Teleoperator): return self.reachy.is_connected() if self.reachy is not None else False def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + self.reachy = ReachySDK(self.config.ip_address) + if not self.is_connected: - raise ConnectionError() + raise DeviceNotConnectedError() logger.info(f"{self} connected.") @property @@ -135,23 +149,24 @@ class Reachy2Teleoperator(Teleoperator): def get_action(self) -> dict[str, float]: start = time.perf_counter() - if self.reachy and self.is_connected: - if self.config.use_present_position: - joint_action = { - k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items() - } - else: - joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()} + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") - if not self.config.with_mobile_base: - dt_ms = (time.perf_counter() - start) * 1e3 - logger.debug(f"{self} read action: {dt_ms:.1f}ms") - return joint_action + joint_action: dict[str, float] = {} + vel_action: dict[str, float] = {} - if self.config.use_present_position: - vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v 
in REACHY2_VEL.items()} - else: - vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()} + if self.config.use_present_position: + joint_action = {k: self.reachy.joints[v].present_position for k, v in self.joints_dict.items()} + else: + joint_action = {k: self.reachy.joints[v].goal_position for k, v in self.joints_dict.items()} + if not self.config.with_mobile_base: + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return joint_action + if self.config.use_present_position: + vel_action = {k: self.reachy.mobile_base.odometry[v] for k, v in REACHY2_VEL.items()} + else: + vel_action = {k: self.reachy.mobile_base.last_cmd_vel[v] for k, v in REACHY2_VEL.items()} dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read action: {dt_ms:.1f}ms") return {**joint_action, **vel_action} @@ -160,5 +175,5 @@ class Reachy2Teleoperator(Teleoperator): raise NotImplementedError def disconnect(self) -> None: - if self.reachy and self.is_connected: + if self.is_connected: self.reachy.disconnect() diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index e6817ba6c..0206a8ac9 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -64,6 +64,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b _transformers_available = is_package_available("transformers") _peft_available = is_package_available("peft") _scipy_available = is_package_available("scipy") +_reachy2_sdk_available = is_package_available("reachy2_sdk") def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/tests/cameras/test_reachy2_camera.py b/tests/cameras/test_reachy2_camera.py index 0b38e8b0b..14774bf38 100644 --- a/tests/cameras/test_reachy2_camera.py +++ b/tests/cameras/test_reachy2_camera.py @@ -20,6 +20,8 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest 
+pytest.importorskip("reachy2_sdk") + from lerobot.cameras.reachy2_camera import Reachy2Camera, Reachy2CameraConfig from lerobot.utils.errors import DeviceNotConnectedError @@ -127,24 +129,12 @@ def test_async_read(camera): try: img = camera.async_read() - assert camera.thread is not None - assert camera.thread.is_alive() assert isinstance(img, np.ndarray) finally: if camera.is_connected: camera.disconnect() -def test_async_read_timeout(camera): - camera.connect() - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() - - def test_read_before_connect(camera): with pytest.raises(DeviceNotConnectedError): _ = camera.read() diff --git a/tests/conftest.py b/tests/conftest.py index b14e9aed5..2fcf878ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,7 +28,6 @@ pytest_plugins = [ "tests.fixtures.files", "tests.fixtures.hub", "tests.fixtures.optimizers", - "tests.plugins.reachy2_sdk", ] diff --git a/tests/plugins/reachy2_sdk.py b/tests/plugins/reachy2_sdk.py deleted file mode 100644 index 457fcf0f9..000000000 --- a/tests/plugins/reachy2_sdk.py +++ /dev/null @@ -1,46 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import sys -import types -from unittest.mock import MagicMock - - -def _install_reachy2_sdk_stub(): - sdk = types.ModuleType("reachy2_sdk") - sdk.__path__ = [] - sdk.ReachySDK = MagicMock(name="ReachySDK") - - media = types.ModuleType("reachy2_sdk.media") - media.__path__ = [] - camera = types.ModuleType("reachy2_sdk.media.camera") - camera.CameraView = MagicMock(name="CameraView") - camera_manager = types.ModuleType("reachy2_sdk.media.camera_manager") - camera_manager.CameraManager = MagicMock(name="CameraManager") - - sdk.media = media - media.camera = camera - media.camera_manager = camera_manager - - # Register in sys.modules - sys.modules.setdefault("reachy2_sdk", sdk) - sys.modules.setdefault("reachy2_sdk.media", media) - sys.modules.setdefault("reachy2_sdk.media.camera", camera) - sys.modules.setdefault("reachy2_sdk.media.camera_manager", camera_manager) - - -def pytest_sessionstart(session): - _install_reachy2_sdk_stub() diff --git a/tests/robots/test_reachy2.py b/tests/robots/test_reachy2.py index 94152ea38..d3c44bf5a 100644 --- a/tests/robots/test_reachy2.py +++ b/tests/robots/test_reachy2.py @@ -19,6 +19,8 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest +pytest.importorskip("reachy2_sdk") + from lerobot.robots.reachy2 import ( REACHY2_ANTENNAS_JOINTS, REACHY2_L_ARM_JOINTS, From d0f57f58d1e24688664863040ef24cb8a1b37374 Mon Sep 17 00:00:00 2001 From: samet-rob Date: Mon, 12 Jan 2026 21:24:19 +0300 Subject: [PATCH 034/111] Move cfg.validate() earlier to fix NoneType error with --policy.path (#2782) --- src/lerobot/scripts/lerobot_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 927e9659a..55737d5d8 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -251,6 +251,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): cfg: A `TrainPipelineConfig` 
object containing all training configurations. accelerator: Optional Accelerator instance. If None, one will be created automatically. """ + cfg.validate() + # Create Accelerator if not provided # It will automatically detect if running in distributed mode or single-process mode # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes @@ -274,8 +276,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # When using accelerate, only the main process should log to avoid duplicate outputs is_main_process = accelerator.is_main_process - cfg.validate() - # Only log on main process if is_main_process: logging.info(pformat(cfg.to_dict())) From 2cdd9f43f7e104702ed74c9b6e809d079527314e Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 13 Jan 2026 01:42:53 +0100 Subject: [PATCH 035/111] fix: train tokenizer CLI entry point (#2784) --- .../scripts/lerobot_train_tokenizer.py | 177 +++++++++++------- 1 file changed, 105 insertions(+), 72 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 03bfcaaf8..296447bad 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -14,11 +14,14 @@ """Train FAST tokenizer for action encoding. This script: -1. Loads action chunks from LeRobotDataset (with sampling) -2. Applies delta transforms and per-timestamp normalization -3. Trains FAST tokenizer on specified action dimensions -4. Saves tokenizer to assets directory -5. Reports compression statistics +1. Loads action chunks from LeRobotDataset (with episode sampling) +2. Optionally applies delta transforms (relative vs absolute actions) +3. Extracts specified action dimensions for encoding +4. Applies normalization (MEAN_STD, MIN_MAX, QUANTILES, or other modes) +5. Trains FAST tokenizer (BPE on DCT coefficients) on the action chunks +6. 
Saves tokenizer to output directory +7. Optionally pushes tokenizer to Hugging Face Hub +8. Reports compression statistics Example: @@ -42,18 +45,64 @@ lerobot-train-tokenizer \ """ import json +from dataclasses import dataclass from pathlib import Path +from typing import TYPE_CHECKING import numpy as np import torch -import tyro from huggingface_hub import HfApi -from transformers import AutoProcessor +from lerobot.utils.import_utils import _transformers_available + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoProcessor +else: + AutoProcessor = None + +from lerobot.configs import parser from lerobot.configs.types import NormalizationMode from lerobot.datasets.lerobot_dataset import LeRobotDataset +@dataclass +class TokenizerTrainingConfig: + """Configuration for training FAST tokenizer.""" + + # LeRobot dataset repository ID + repo_id: str + # Root directory for dataset (default: ~/.cache/huggingface/lerobot) + root: str | None = None + # Number of future actions in each chunk + action_horizon: int = 10 + # Max episodes to use (None = all episodes in dataset) + max_episodes: int | None = None + # Fraction of chunks to sample per episode + sample_fraction: float = 0.1 + # Comma-separated dimension ranges to encode (e.g., "0:6,7:23") + encoded_dims: str = "0:6,7:23" + # Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5") + delta_dims: str | None = None + # Whether to apply delta transform (relative actions vs absolute actions) + use_delta_transform: bool = False + # Dataset key for state observations (default: "observation.state") + state_key: str = "observation.state" + # Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) + normalization_mode: str = "QUANTILES" + # FAST vocabulary size (BPE vocab size) + vocab_size: int = 1024 + # DCT scaling factor (default: 10.0) + scale: float = 10.0 + # Directory to save tokenizer (default: ./fast_tokenizer_{repo_id}) + output_dir: str | None = None + 
# Whether to push the tokenizer to Hugging Face Hub + push_to_hub: bool = False + # Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name + hub_repo_id: str | None = None + # Whether to create a private repository on the Hub + hub_private: bool = False + + def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray: """Apply delta transform to specified dimensions. @@ -327,88 +376,57 @@ def compute_compression_stats(tokenizer, action_chunks: np.ndarray): return stats -def main( - repo_id: str, - root: str | None = None, - action_horizon: int = 10, - max_episodes: int | None = None, - sample_fraction: float = 0.1, - encoded_dims: str = "0:6,7:23", - delta_dims: str | None = None, - use_delta_transform: bool = False, - state_key: str = "observation.state", - normalization_mode: str = "QUANTILES", - vocab_size: int = 1024, - scale: float = 10.0, - output_dir: str | None = None, - push_to_hub: bool = False, - hub_repo_id: str | None = None, - hub_private: bool = False, -): +@parser.wrap() +def train_tokenizer(cfg: TokenizerTrainingConfig): """ Train FAST tokenizer for action encoding. 
Args: - repo_id: LeRobot dataset repository ID - root: Root directory for dataset (default: ~/.cache/huggingface/lerobot) - action_horizon: Number of future actions in each chunk - max_episodes: Max episodes to use (None = all episodes in dataset) - sample_fraction: Fraction of chunks to sample per episode - encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23") - delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5") - use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions) - state_key: Dataset key for state observations (default: "observation.state") - normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) - vocab_size: FAST vocabulary size (BPE vocab size) - scale: DCT scaling factor (default: 10.0) - output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id}) - push_to_hub: Whether to push the tokenizer to Hugging Face Hub - hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name - hub_private: Whether to create a private repository on the Hub + cfg: TokenizerTrainingConfig dataclass with all configuration parameters """ # load dataset - print(f"Loading dataset: {repo_id}") - dataset = LeRobotDataset(repo_id=repo_id, root=root) + print(f"Loading dataset: {cfg.repo_id}") + dataset = LeRobotDataset(repo_id=cfg.repo_id, root=cfg.root) print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames") # parse normalization mode try: - norm_mode = NormalizationMode(normalization_mode) + norm_mode = NormalizationMode(cfg.normalization_mode) except ValueError as err: raise ValueError( - f"Invalid normalization_mode: {normalization_mode}. " + f"Invalid normalization_mode: {cfg.normalization_mode}. 
" f"Must be one of: {', '.join([m.value for m in NormalizationMode])}" ) from err print(f"Normalization mode: {norm_mode.value}") # parse encoded dimensions encoded_dim_ranges = [] - for range_str in encoded_dims.split(","): + for range_str in cfg.encoded_dims.split(","): start, end = map(int, range_str.strip().split(":")) encoded_dim_ranges.append((start, end)) total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges) - print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}") + print(f"Encoding {total_encoded_dims} dimensions: {cfg.encoded_dims}") # parse delta dimensions delta_dim_list = None - if delta_dims is not None and delta_dims.strip(): - delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")] + if cfg.delta_dims is not None and cfg.delta_dims.strip(): + delta_dim_list = [int(d.strip()) for d in cfg.delta_dims.split(",")] print(f"Delta dimensions: {delta_dim_list}") else: print("No delta dimensions specified") - print(f"Use delta transform: {use_delta_transform}") - if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0): + print(f"Use delta transform: {cfg.use_delta_transform}") + if cfg.use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0): print("Warning: use_delta_transform=True but no delta_dims specified. 
No delta will be applied.") - print(f"Action horizon: {action_horizon}") - print(f"State key: {state_key}") + print(f"Action horizon: {cfg.action_horizon}") + print(f"State key: {cfg.state_key}") # determine episodes to process num_episodes = dataset.num_episodes - if max_episodes is not None: - num_episodes = min(max_episodes, num_episodes) + if cfg.max_episodes is not None: + num_episodes = min(cfg.max_episodes, num_episodes) print(f"Processing {num_episodes} episodes...") @@ -419,7 +437,15 @@ def main( print(f" Processing episode {ep_idx}/{num_episodes}...") chunks = process_episode( - (dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform) + ( + dataset, + ep_idx, + cfg.action_horizon, + delta_dim_list, + cfg.sample_fraction, + cfg.state_key, + cfg.use_delta_transform, + ) ) if chunks is not None: all_chunks.append(chunks) @@ -495,16 +521,17 @@ def main( # train FAST tokenizer tokenizer = train_fast_tokenizer( encoded_chunks, - vocab_size=vocab_size, - scale=scale, + vocab_size=cfg.vocab_size, + scale=cfg.scale, ) # compute compression statistics compression_stats = compute_compression_stats(tokenizer, encoded_chunks) # save tokenizer + output_dir = cfg.output_dir if output_dir is None: - output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}" + output_dir = f"fast_tokenizer_{cfg.repo_id.replace('/', '_')}" output_path = Path(output_dir) output_path.mkdir(parents=True, exist_ok=True) @@ -512,18 +539,18 @@ def main( # save metadata metadata = { - "repo_id": repo_id, - "vocab_size": vocab_size, - "scale": scale, - "encoded_dims": encoded_dims, + "repo_id": cfg.repo_id, + "vocab_size": cfg.vocab_size, + "scale": cfg.scale, + "encoded_dims": cfg.encoded_dims, "encoded_dim_ranges": encoded_dim_ranges, "total_encoded_dims": total_encoded_dims, - "delta_dims": delta_dims, + "delta_dims": cfg.delta_dims, "delta_dim_list": delta_dim_list, - "use_delta_transform": use_delta_transform, - "state_key": state_key, + 
"use_delta_transform": cfg.use_delta_transform, + "state_key": cfg.state_key, "normalization_mode": norm_mode.value, - "action_horizon": action_horizon, + "action_horizon": cfg.action_horizon, "num_training_chunks": len(encoded_chunks), "compression_stats": compression_stats, } @@ -535,21 +562,22 @@ def main( print(f"Metadata: {json.dumps(metadata, indent=2)}") # push to Hugging Face Hub if requested - if push_to_hub: + if cfg.push_to_hub: # determine the hub repository ID + hub_repo_id = cfg.hub_repo_id if hub_repo_id is None: hub_repo_id = output_path.name print(f"\nNo hub_repo_id provided, using: {hub_repo_id}") print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}") - print(f" Private: {hub_private}") + print(f" Private: {cfg.hub_private}") try: # use the tokenizer's push_to_hub method tokenizer.push_to_hub( repo_id=hub_repo_id, - private=hub_private, - commit_message=f"Upload FAST tokenizer trained on {repo_id}", + private=cfg.hub_private, + commit_message=f"Upload FAST tokenizer trained on {cfg.repo_id}", ) # also upload the metadata.json file separately @@ -568,5 +596,10 @@ def main( print(" Make sure you're logged in with `huggingface-cli login`") +def main(): + """CLI entry point that parses arguments and runs the tokenizer training.""" + train_tokenizer() + + if __name__ == "__main__": - tyro.cli(main) + main() From 15724826dd5b18c18656151a02a24ce5fd690c46 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 13 Jan 2026 09:49:46 +0100 Subject: [PATCH 036/111] chore: use alias & constants (#2785) * chore: use alias and constants * fix(rl): solve circular dependency * chore: nit right constant * chore: pre-commit * chore(script): conflict tokenizer train --------- Signed-off-by: Steven Palma --- src/lerobot/async_inference/robot_client.py | 4 +-- src/lerobot/datasets/pipeline_features.py | 4 +-- src/lerobot/envs/libero.py | 6 ++-- src/lerobot/envs/metaworld.py | 12 ++++---- src/lerobot/envs/utils.py | 3 +- src/lerobot/policies/factory.py | 8
++++-- .../policies/groot/configuration_groot.py | 13 +++++---- src/lerobot/policies/groot/groot_n1.py | 6 ++-- src/lerobot/policies/groot/modeling_groot.py | 3 +- src/lerobot/policies/groot/processor_groot.py | 28 +++++++++++-------- src/lerobot/policies/pi0/configuration_pi0.py | 10 +++---- .../policies/pi05/configuration_pi05.py | 11 ++++---- .../pi0_fast/configuration_pi0_fast.py | 11 ++++---- .../policies/sarm/configuration_sarm.py | 5 ++-- src/lerobot/policies/sarm/modeling_sarm.py | 3 +- src/lerobot/policies/utils.py | 3 +- .../policies/wall_x/configuration_wall_x.py | 13 +++++---- .../policies/wall_x/modeling_wall_x.py | 6 ++-- src/lerobot/policies/xvla/processor_xvla.py | 16 ++++++----- src/lerobot/processor/__init__.py | 3 -- src/lerobot/processor/converters.py | 6 ++-- src/lerobot/processor/core.py | 2 +- src/lerobot/processor/env_processor.py | 11 ++++---- src/lerobot/processor/normalize_processor.py | 4 +-- src/lerobot/processor/pipeline.py | 6 ++-- src/lerobot/processor/tokenizer_processor.py | 4 +-- src/lerobot/rl/gym_manipulator.py | 11 ++++---- .../joint_observations_processor.py | 0 .../robots/bi_so_follower/bi_so_follower.py | 6 ++-- .../robot_earthrover_mini_plus.py | 11 ++++---- src/lerobot/robots/hope_jr/hope_jr_arm.py | 6 ++-- src/lerobot/robots/hope_jr/hope_jr_hand.py | 6 ++-- .../robots/koch_follower/koch_follower.py | 10 +++---- src/lerobot/robots/lekiwi/lekiwi.py | 7 +++-- src/lerobot/robots/lekiwi/lekiwi_client.py | 19 ++++++------- .../robots/omx_follower/omx_follower.py | 10 +++---- src/lerobot/robots/reachy2/robot_reachy2.py | 9 +++--- src/lerobot/robots/robot.py | 12 ++++---- .../so_follower/robot_kinematic_processor.py | 3 +- src/lerobot/robots/so_follower/so_follower.py | 9 +++--- src/lerobot/robots/unitree_g1/unitree_g1.py | 5 ++-- src/lerobot/scripts/lerobot_eval.py | 4 +-- src/lerobot/scripts/lerobot_teleoperate.py | 2 +- .../scripts/lerobot_train_tokenizer.py | 13 ++++----- .../teleoperators/gamepad/teleop_gamepad.py | 4 ++- 
.../teleoperators/keyboard/teleop_keyboard.py | 9 +++--- src/lerobot/teleoperators/teleoperator.py | 5 ++-- src/lerobot/utils/constants.py | 2 ++ src/lerobot/utils/visualization_utils.py | 11 ++++---- tests/mocks/mock_robot.py | 6 ++-- tests/mocks/mock_teleop.py | 3 +- 51 files changed, 206 insertions(+), 178 deletions(-) rename src/lerobot/{processor => rl}/joint_observations_processor.py (100%) diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index eea5585b0..f26639dc1 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -40,7 +40,6 @@ from collections.abc import Callable from dataclasses import asdict from pprint import pformat from queue import Queue -from typing import Any import draccus import grpc @@ -48,6 +47,7 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.processor import RobotAction from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -351,7 +351,7 @@ class RobotClient: action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)} return action - def control_loop_action(self, verbose: bool = False) -> dict[str, Any]: + def control_loop_action(self, verbose: bool = False) -> RobotAction: """Reading and performing actions in local queue""" # Lock only for queue operations diff --git a/src/lerobot/datasets/pipeline_features.py b/src/lerobot/datasets/pipeline_features.py index 4fad7bd20..161633f26 100644 --- a/src/lerobot/datasets/pipeline_features.py +++ b/src/lerobot/datasets/pipeline_features.py @@ -18,12 +18,12 @@ from typing import Any from lerobot.configs.types import PipelineFeatureType from lerobot.datasets.utils import hw_to_dataset_features -from lerobot.processor import DataProcessorPipeline +from lerobot.processor import 
DataProcessorPipeline, RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE, OBS_STR def create_initial_features( - action: dict[str, Any] | None = None, observation: dict[str, Any] | None = None + action: RobotAction | None = None, observation: RobotObservation | None = None ) -> dict[PipelineFeatureType, dict[str, Any]]: """ Creates the initial features dict for the dataset from action and observation specs. diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index b1eb37377..74882ad18 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -29,6 +29,8 @@ from gymnasium import spaces from libero.libero import benchmark, get_libero_path from libero.libero.envs import OffScreenRenderEnv +from lerobot.processor import RobotObservation + def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]: """Normalize camera_name into a non-empty list of strings.""" @@ -237,7 +239,7 @@ class LiberoEnv(gym.Env): env.reset() return env - def _format_raw_obs(self, raw_obs: dict[str, Any]) -> dict[str, Any]: + def _format_raw_obs(self, raw_obs: RobotObservation) -> RobotObservation: images = {} for camera_name in self.camera_name: image = raw_obs[camera_name] @@ -313,7 +315,7 @@ class LiberoEnv(gym.Env): info = {"is_success": False} return observation, info - def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: if action.ndim != 1: raise ValueError( f"Expected action to be 1-D (shape (action_dim,)), " diff --git a/src/lerobot/envs/metaworld.py b/src/lerobot/envs/metaworld.py index 9190f33ad..4d91e002d 100644 --- a/src/lerobot/envs/metaworld.py +++ b/src/lerobot/envs/metaworld.py @@ -25,6 +25,8 @@ import metaworld.policies as policies import numpy as np from gymnasium import spaces +from lerobot.processor import RobotObservation + # ---- Load 
configuration data from the external JSON file ---- CONFIG_PATH = Path(__file__).parent / "metaworld_config.json" try: @@ -161,7 +163,7 @@ class MetaworldEnv(gym.Env): env._freeze_rand_vec = False # otherwise no randomization return env - def _format_raw_obs(self, raw_obs: np.ndarray) -> dict[str, Any]: + def _format_raw_obs(self, raw_obs: np.ndarray) -> RobotObservation: image = None if self._env is not None: image = self._env.render() @@ -196,7 +198,7 @@ class MetaworldEnv(gym.Env): self, seed: int | None = None, **kwargs, - ) -> tuple[dict[str, Any], dict[str, Any]]: + ) -> tuple[RobotObservation, dict[str, Any]]: """ Reset the environment to its initial state. @@ -204,7 +206,7 @@ class MetaworldEnv(gym.Env): seed (Optional[int]): Random seed for environment initialization. Returns: - observation (Dict[str, Any]): The initial formatted observation. + observation (RobotObservation): The initial formatted observation. info (Dict[str, Any]): Additional info about the reset state. """ super().reset(seed=seed) @@ -216,7 +218,7 @@ class MetaworldEnv(gym.Env): info = {"is_success": False} return observation, info - def step(self, action: np.ndarray) -> tuple[dict[str, Any], float, bool, bool, dict[str, Any]]: + def step(self, action: np.ndarray) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: """ Perform one environment step. @@ -224,7 +226,7 @@ class MetaworldEnv(gym.Env): action (np.ndarray): The action to execute, must be 1-D with shape (action_dim,). Returns: - observation (Dict[str, Any]): The formatted observation after the step. + observation (RobotObservation): The formatted observation after the step. reward (float): The scalar reward for this step. terminated (bool): Whether the episode terminated successfully. truncated (bool): Whether the episode was truncated due to a time limit. 
diff --git a/src/lerobot/envs/utils.py b/src/lerobot/envs/utils.py index af814c92a..09431a18d 100644 --- a/src/lerobot/envs/utils.py +++ b/src/lerobot/envs/utils.py @@ -29,6 +29,7 @@ from torch import Tensor from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.envs.configs import EnvConfig +from lerobot.processor import RobotObservation from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR from lerobot.utils.utils import get_channel_first_image_shape @@ -152,7 +153,7 @@ def check_env_attributes_and_types(env: gym.vector.VectorEnv) -> None: ) -def add_envs_task(env: gym.vector.VectorEnv, observation: dict[str, Any]) -> dict[str, Any]: +def add_envs_task(env: gym.vector.VectorEnv, observation: RobotObservation) -> RobotObservation: """Adds task feature to the observation dict with respect to the first environment attribute.""" if hasattr(env.envs[0], "task_description"): task_result = env.call("task_description") diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index fff08ad37..a593e5bcb 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -51,7 +51,11 @@ from lerobot.processor.converters import ( transition_to_batch, transition_to_policy_action, ) -from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME +from lerobot.utils.constants import ( + ACTION, + POLICY_POSTPROCESSOR_DEFAULT_NAME, + POLICY_PREPROCESSOR_DEFAULT_NAME, +) def get_policy_class(name: str) -> type[PreTrainedPolicy]: @@ -250,7 +254,7 @@ def make_pre_post_processors( } # Also ensure postprocessing slices to env action dim and unnormalizes with dataset stats - env_action_dim = policy_cfg.output_features["action"].shape[0] + env_action_dim = policy_cfg.output_features[ACTION].shape[0] postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = { "stats": kwargs.get("dataset_stats"), "normalize_min_max": True, diff --git 
a/src/lerobot/policies/groot/configuration_groot.py b/src/lerobot/policies/groot/configuration_groot.py index 8002c69ea..4f3d78222 100644 --- a/src/lerobot/policies/groot/configuration_groot.py +++ b/src/lerobot/policies/groot/configuration_groot.py @@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.utils.constants import ACTION, OBS_STATE @PreTrainedConfig.register_subclass("groot") @@ -137,14 +138,14 @@ class GrootConfig(PreTrainedConfig): "No features of type FeatureType.VISUAL found in input_features." ) - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature else: - state_shape = self.input_features["observation.state"].shape + state_shape = self.input_features[OBS_STATE].shape state_dim = state_shape[0] if state_shape else 0 if state_dim > self.max_state_dim: raise ValueError( @@ -152,14 +153,14 @@ class GrootConfig(PreTrainedConfig): f"Either reduce state dimension or increase max_state_dim in config." 
) - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature else: - action_shape = self.output_features["action"].shape + action_shape = self.output_features[ACTION].shape action_dim = action_shape[0] if action_shape else 0 if action_dim > self.max_action_dim: raise ValueError( diff --git a/src/lerobot/policies/groot/groot_n1.py b/src/lerobot/policies/groot/groot_n1.py index 0da26874e..06ff5a04d 100644 --- a/src/lerobot/policies/groot/groot_n1.py +++ b/src/lerobot/policies/groot/groot_n1.py @@ -46,7 +46,7 @@ from lerobot.policies.groot.action_head.flow_matching_action_head import ( FlowmatchingActionHeadConfig, ) from lerobot.policies.groot.utils import ensure_eagle_cache_ready -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.utils.constants import ACTION, HF_LEROBOT_HOME DEFAULT_VENDOR_EAGLE_PATH = str((Path(__file__).resolve().parent / "eagle2_hg_model").resolve()) DEFAULT_TOKENIZER_ASSETS_REPO = "lerobot/eagle2hg-processor-groot-n1p5" @@ -227,8 +227,8 @@ class GR00TN15(PreTrainedModel): detected_error = False error_msg = ERROR_MSG - if "action" in inputs: - action = inputs["action"] + if ACTION in inputs: + action = inputs[ACTION] # In inference, action may be omitted or None; validate only when it's a tensor. 
if action is None: pass # allow None during inference diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index bdaef37b9..fd9baa9b1 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -41,6 +41,7 @@ from torch import Tensor from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.groot_n1 import GR00TN15 from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.utils.constants import ACTION class GrootPolicy(PreTrainedPolicy): @@ -147,7 +148,7 @@ class GrootPolicy(PreTrainedPolicy): actions = outputs.get("action_pred") - original_action_dim = self.config.output_features["action"].shape[0] + original_action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :original_action_dim] return actions diff --git a/src/lerobot/policies/groot/processor_groot.py b/src/lerobot/policies/groot/processor_groot.py index d87e43c11..14149cf2f 100644 --- a/src/lerobot/policies/groot/processor_groot.py +++ b/src/lerobot/policies/groot/processor_groot.py @@ -51,7 +51,11 @@ from lerobot.processor.converters import ( ) from lerobot.processor.core import EnvTransition, TransitionKey from lerobot.utils.constants import ( + ACTION, HF_LEROBOT_HOME, + OBS_IMAGE, + OBS_IMAGES, + OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, ) @@ -107,9 +111,9 @@ def make_groot_pre_post_processors( # Define feature specs for optional normalization steps _features: dict[str, PolicyFeature] = { # Observation features (only add those we may normalize) - "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)), + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_horizon, max_state_dim)), # Action feature - "action": PolicyFeature(type=FeatureType.ACTION, shape=(action_horizon, max_action_dim)), + ACTION: PolicyFeature(type=FeatureType.ACTION, 
shape=(action_horizon, max_action_dim)), } # Normalize STATE and ACTION with min_max (SO100-like default) @@ -120,7 +124,7 @@ def make_groot_pre_post_processors( # Determine env action dimension from config (simple, object-like PolicyFeature) try: - env_action_dim = int(config.output_features["action"].shape[0]) + env_action_dim = int(config.output_features[ACTION].shape[0]) except Exception: env_action_dim = 0 @@ -268,9 +272,9 @@ class GrootPackInputsStep(ProcessorStep): return torch.where(mask, mapped, torch.zeros_like(mapped)) # 1) Video (B, T=1, V, H, W, C) uint8 - img_keys = sorted([k for k in obs if k.startswith("observation.images.")]) - if not img_keys and "observation.image" in obs: - img_keys = ["observation.image"] + img_keys = sorted([k for k in obs if k.startswith(OBS_IMAGES)]) + if not img_keys and OBS_IMAGE in obs: + img_keys = [OBS_IMAGE] if img_keys: cams = [_to_uint8_np_bhwc(obs[k]) for k in img_keys] video = np.stack(cams, axis=1) # (B, V, H, W, C) @@ -294,14 +298,14 @@ class GrootPackInputsStep(ProcessorStep): comp["language"] = lang # 3) State/state_mask -> (B, 1, max_state_dim) - if "observation.state" in obs: - state = obs["observation.state"] # (B, D) + if OBS_STATE in obs: + state = obs[OBS_STATE] # (B, D) if state.dim() != 2: raise ValueError(f"state must be (B, D), got {tuple(state.shape)}") bsz, d = state.shape # Normalize BEFORE padding if self.normalize_min_max: - state = _min_max_norm(state, "observation.state") + state = _min_max_norm(state, OBS_STATE) state = state.unsqueeze(1) # (B, 1, D) if d > self.max_state_dim: state = state[:, :, : self.max_state_dim] @@ -320,11 +324,11 @@ class GrootPackInputsStep(ProcessorStep): # Normalize BEFORE temporal expansion/padding if self.normalize_min_max: if action.dim() == 2: - action = _min_max_norm(action, "action") + action = _min_max_norm(action, ACTION) elif action.dim() == 3: b, t, d = action.shape flat = action.reshape(b * t, d) - flat = _min_max_norm(flat, "action") + flat = 
_min_max_norm(flat, ACTION) action = flat.view(b, t, d) if action.dim() == 2: action = action.unsqueeze(1).repeat(1, self.action_horizon, 1) @@ -590,7 +594,7 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep): # forward: y = 2 * (x - min) / denom - 1, with y=0 when denom==0 # inverse: x = (y+1)/2 * denom + min, and when denom==0 -> x = min if self.normalize_min_max and self.stats is not None: - stats_k = self.stats.get("action", {}) + stats_k = self.stats.get(ACTION, {}) d = action.shape[-1] min_v = torch.as_tensor( stats_k.get("min", torch.zeros(d)), dtype=action.dtype, device=action.device diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py index d91145aa7..be9b4530f 100644 --- a/src/lerobot/policies/pi0/configuration_pi0.py +++ b/src/lerobot/policies/pi0/configuration_pi0.py @@ -21,7 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig -from lerobot.utils.constants import OBS_IMAGES +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -124,19 +124,19 @@ class PI0Config(PreTrainedConfig): ) self.input_features[key] = empty_camera - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature def 
get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( diff --git a/src/lerobot/policies/pi05/configuration_pi05.py b/src/lerobot/policies/pi05/configuration_pi05.py index 4a0e23039..b96e6d196 100644 --- a/src/lerobot/policies/pi05/configuration_pi05.py +++ b/src/lerobot/policies/pi05/configuration_pi05.py @@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -117,26 +118,26 @@ class PI05Config(PreTrainedConfig): def validate_features(self) -> None: """Validate and set up input/output features.""" for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = OBS_IMAGES + f".empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, *self.image_resolution), # Use configured image resolution ) self.input_features[key] = empty_camera - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py index 42aa4a132..96137e91f 100644 --- 
a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -21,6 +21,7 @@ from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.utils.constants import ACTION, OBS_IMAGES, OBS_STATE DEFAULT_IMAGE_SIZE = 224 @@ -110,26 +111,26 @@ class PI0FastConfig(PreTrainedConfig): def validate_features(self) -> None: """Validate and set up input/output features.""" for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" + key = OBS_IMAGES + f".empty_camera_{i}" empty_camera = PolicyFeature( type=FeatureType.VISUAL, shape=(3, *self.image_resolution), # Use configured image resolution ) self.input_features[key] = empty_camera - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature def get_optimizer_preset(self) -> AdamWConfig: return AdamWConfig( diff --git a/src/lerobot/policies/sarm/configuration_sarm.py b/src/lerobot/policies/sarm/configuration_sarm.py index 59cb352d5..673422fe2 100644 --- a/src/lerobot/policies/sarm/configuration_sarm.py +++ b/src/lerobot/policies/sarm/configuration_sarm.py @@ -26,6 +26,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, 
NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.utils.constants import OBS_IMAGES, OBS_STATE @PreTrainedConfig.register_subclass("sarm") @@ -86,8 +87,8 @@ class SARMConfig(PreTrainedConfig): pretrained_model_path: str | None = None device: str | None = None - image_key: str = "observation.images.top" # Key for image used from the dataset - state_key: str = "observation.state" + image_key: str = OBS_IMAGES + ".top" # Key for image used from the dataset + state_key: str = OBS_STATE # Populated by the processor (video_features, state_features, text_features) input_features: dict = field(default_factory=lambda: {}) diff --git a/src/lerobot/policies/sarm/modeling_sarm.py b/src/lerobot/policies/sarm/modeling_sarm.py index a88b2ad64..6051d90f8 100644 --- a/src/lerobot/policies/sarm/modeling_sarm.py +++ b/src/lerobot/policies/sarm/modeling_sarm.py @@ -40,6 +40,7 @@ from lerobot.policies.sarm.sarm_utils import ( normalize_stage_tau, pad_state_to_max_dim, ) +from lerobot.utils.constants import OBS_STR class StageTransformer(nn.Module): @@ -721,7 +722,7 @@ class SARMRewardModel(PreTrainedPolicy): Returns: Tuple of (total_loss, output_dict with loss components) """ - observation = batch.get("observation", batch) + observation = batch.get(OBS_STR, batch) # Extract features video_features = observation["video_features"].to(self.device) diff --git a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index bfbe2bf1d..1a14b2925 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -16,7 +16,6 @@ import logging from collections import deque -from typing import Any import numpy as np import torch @@ -140,7 +139,7 @@ def prepare_observation_for_inference( def build_inference_frame( - observation: dict[str, Any], + observation: RobotObservation, device: torch.device, ds_features: dict[str, dict], task: str | None = None, diff 
--git a/src/lerobot/policies/wall_x/configuration_wall_x.py b/src/lerobot/policies/wall_x/configuration_wall_x.py index 0d10a8f98..3962b56f6 100644 --- a/src/lerobot/policies/wall_x/configuration_wall_x.py +++ b/src/lerobot/policies/wall_x/configuration_wall_x.py @@ -18,6 +18,7 @@ from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.optim.optimizers import AdamWConfig from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.utils.constants import ACTION, OBS_STATE @PreTrainedConfig.register_subclass("wall_x") @@ -105,14 +106,14 @@ class WallXConfig(PreTrainedConfig): "No features of type FeatureType.VISUAL found in input_features." ) - if "observation.state" not in self.input_features: + if OBS_STATE not in self.input_features: state_feature = PolicyFeature( type=FeatureType.STATE, shape=(self.max_state_dim,), # Padded to max_state_dim ) - self.input_features["observation.state"] = state_feature + self.input_features[OBS_STATE] = state_feature else: - state_shape = self.input_features["observation.state"].shape + state_shape = self.input_features[OBS_STATE].shape state_dim = state_shape[0] if state_shape else 0 if state_dim > self.max_state_dim: raise ValueError( @@ -120,14 +121,14 @@ class WallXConfig(PreTrainedConfig): f"Either reduce state dimension or increase max_state_dim in config." 
) - if "action" not in self.output_features: + if ACTION not in self.output_features: action_feature = PolicyFeature( type=FeatureType.ACTION, shape=(self.max_action_dim,), # Padded to max_action_dim ) - self.output_features["action"] = action_feature + self.output_features[ACTION] = action_feature else: - action_shape = self.output_features["action"].shape + action_shape = self.output_features[ACTION].shape action_dim = action_shape[0] if action_shape else 0 if action_dim > self.max_action_dim: raise ValueError( diff --git a/src/lerobot/policies/wall_x/modeling_wall_x.py b/src/lerobot/policies/wall_x/modeling_wall_x.py index 94ee7897e..ef99bad89 100644 --- a/src/lerobot/policies/wall_x/modeling_wall_x.py +++ b/src/lerobot/policies/wall_x/modeling_wall_x.py @@ -1861,7 +1861,7 @@ class WallXPolicy(PreTrainedPolicy): dim=-1, ) else: - action_dim = self.config.output_features["action"].shape[0] + action_dim = self.config.output_features[ACTION].shape[0] dof_mask = torch.cat( [ torch.ones( @@ -1977,7 +1977,7 @@ class WallXPolicy(PreTrainedPolicy): elif self.config.prediction_mode == "fast": output = self.model( **batch, - action_dim=self.config.output_features["action"].shape[0], + action_dim=self.config.output_features[ACTION].shape[0], pred_horizon=self.config.chunk_size, mode="predict", predict_mode="fast", @@ -1989,7 +1989,7 @@ class WallXPolicy(PreTrainedPolicy): actions = output["predict_action"] # Unpad actions to actual action dimension - action_dim = self.config.output_features["action"].shape[0] + action_dim = self.config.output_features[ACTION].shape[0] actions = actions[:, :, :action_dim] return actions diff --git a/src/lerobot/policies/xvla/processor_xvla.py b/src/lerobot/policies/xvla/processor_xvla.py index 7f7297b9a..c4e3f2d6f 100644 --- a/src/lerobot/policies/xvla/processor_xvla.py +++ b/src/lerobot/policies/xvla/processor_xvla.py @@ -41,6 +41,7 @@ from lerobot.processor.converters import policy_action_to_transition, transition from 
lerobot.processor.core import EnvTransition, TransitionKey from lerobot.utils.constants import ( OBS_IMAGES, + OBS_PREFIX, OBS_STATE, POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME, @@ -137,8 +138,9 @@ class LiberoProcessorStep(ObservationProcessorStep): processed_obs[key] = img # Process robot_state into a flat state vector - if "observation.robot_state" in processed_obs: - robot_state = processed_obs.pop("observation.robot_state") + robot_state_str = OBS_PREFIX + "robot_state" + if robot_state_str in processed_obs: + robot_state = processed_obs.pop(robot_state_str) # Extract components eef_pos = robot_state["eef"]["pos"] # (B, 3,) @@ -174,8 +176,8 @@ class LiberoProcessorStep(ObservationProcessorStep): state_feats = {} # add our new flattened state - state_feats["observation.state"] = PolicyFeature( - key="observation.state", + state_feats[OBS_STATE] = PolicyFeature( + key=OBS_STATE, shape=(20,), dtype="float32", ) @@ -247,7 +249,7 @@ class XVLAImageScaleProcessorStep(ProcessorStep): keys_to_scale = self.image_keys if keys_to_scale is None: # Auto-detect image keys - keys_to_scale = [k for k in obs if k.startswith("observation.images.")] + keys_to_scale = [k for k in obs if k.startswith(OBS_IMAGES)] # Scale each image for key in keys_to_scale: @@ -303,7 +305,7 @@ class XVLAImageToFloatProcessorStep(ProcessorStep): keys_to_convert = self.image_keys if keys_to_convert is None: # Auto-detect image keys - keys_to_convert = [k for k in obs if k.startswith("observation.images.")] + keys_to_convert = [k for k in obs if k.startswith(OBS_IMAGES)] # Convert each image for key in keys_to_convert: @@ -376,7 +378,7 @@ class XVLAImageNetNormalizeProcessorStep(ProcessorStep): keys_to_normalize = self.image_keys if keys_to_normalize is None: # Auto-detect image keys - keys_to_normalize = [k for k in obs if k.startswith("observation.images.")] + keys_to_normalize = [k for k in obs if k.startswith(OBS_IMAGES)] # Normalize each image for key in 
keys_to_normalize: diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 676ba29ee..164f7da03 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -49,7 +49,6 @@ from .hil_processor import ( RewardClassifierProcessorStep, TimeLimitProcessorStep, ) -from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep from .normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep, hotswap_stats from .observation_processor import VanillaObservationProcessorStep from .pipeline import ( @@ -94,14 +93,12 @@ __all__ = [ "ImageCropResizeProcessorStep", "InfoProcessorStep", "InterventionActionProcessorStep", - "JointVelocityProcessorStep", "make_default_processors", "make_default_teleop_action_processor", "make_default_robot_action_processor", "make_default_robot_observation_processor", "MapDeltaActionToRobotActionStep", "MapTensorToDeltaActionDictStep", - "MotorCurrentProcessorStep", "NormalizerProcessorStep", "Numpy2TorchActionProcessorStep", "ObservationProcessorStep", diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 126be0e36..4f9485fee 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -23,7 +23,7 @@ from typing import Any import numpy as np import torch -from lerobot.utils.constants import ACTION, DONE, OBS_PREFIX, REWARD, TRUNCATED +from lerobot.utils.constants import ACTION, DONE, INFO, OBS_PREFIX, REWARD, TRUNCATED from .core import EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey @@ -176,7 +176,7 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: def create_transition( - observation: dict[str, Any] | None = None, + observation: RobotObservation | None = None, action: PolicyAction | RobotAction | None = None, reward: float = 0.0, done: bool = False, @@ -384,7 +384,7 @@ def transition_to_batch(transition: 
EnvTransition) -> dict[str, Any]: REWARD: transition.get(TransitionKey.REWARD, 0.0), DONE: transition.get(TransitionKey.DONE, False), TRUNCATED: transition.get(TransitionKey.TRUNCATED, False), - "info": transition.get(TransitionKey.INFO, {}), + INFO: transition.get(TransitionKey.INFO, {}), } # Add complementary data. diff --git a/src/lerobot/processor/core.py b/src/lerobot/processor/core.py index 679ba8c54..0b293c9b0 100644 --- a/src/lerobot/processor/core.py +++ b/src/lerobot/processor/core.py @@ -45,7 +45,7 @@ RobotObservation: TypeAlias = dict[str, Any] EnvTransition = TypedDict( "EnvTransition", { - TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.OBSERVATION.value: RobotObservation | None, TransitionKey.ACTION.value: PolicyAction | RobotAction | EnvAction | None, TransitionKey.REWARD.value: float | torch.Tensor | None, TransitionKey.DONE.value: bool | torch.Tensor | None, diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index a5210af30..8d42bfdb7 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -18,7 +18,7 @@ from dataclasses import dataclass import torch from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR +from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -60,8 +60,9 @@ class LiberoProcessorStep(ObservationProcessorStep): processed_obs[key] = img # Process robot_state into a flat state vector - if "observation.robot_state" in processed_obs: - robot_state = processed_obs.pop("observation.robot_state") + observation_robot_state_str = OBS_PREFIX + "robot_state" + if observation_robot_state_str in processed_obs: + robot_state = processed_obs.pop(observation_robot_state_str) # Extract components eef_pos = robot_state["eef"]["pos"] # (B, 3,) @@ -98,8 +99,8 @@ class 
LiberoProcessorStep(ObservationProcessorStep): state_feats = {} # add our new flattened state - state_feats["observation.state"] = PolicyFeature( - key="observation.state", + state_feats[OBS_STATE] = PolicyFeature( + key=OBS_STATE, shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] dtype="float32", description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index 368c9b270..4769b91ac 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -30,7 +30,7 @@ from lerobot.utils.constants import ACTION from .converters import from_tensor_to_numpy, to_tensor from .core import EnvTransition, PolicyAction, TransitionKey -from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry +from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry, RobotObservation @dataclass @@ -239,7 +239,7 @@ class _NormalizationMixin: config["normalize_observation_keys"] = sorted(self.normalize_observation_keys) return config - def _normalize_observation(self, observation: dict[str, Any], inverse: bool) -> dict[str, Tensor]: + def _normalize_observation(self, observation: RobotObservation, inverse: bool) -> dict[str, Tensor]: """ Applies (un)normalization to all relevant features in an observation dictionary. 
diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index e14d8b0b9..97ec716ff 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -49,7 +49,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.utils.hub import HubMixin from .converters import batch_to_transition, create_transition, transition_to_batch -from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, TransitionKey +from .core import EnvAction, EnvTransition, PolicyAction, RobotAction, RobotObservation, TransitionKey # Generic type variables for pipeline input and output. TInput = TypeVar("TInput") @@ -1337,7 +1337,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): return features # Convenience methods for processing individual parts of a transition. - def process_observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def process_observation(self, observation: RobotObservation) -> RobotObservation: """Processes only the observation part of a transition through the pipeline. Args: @@ -1440,7 +1440,7 @@ class ObservationProcessorStep(ProcessorStep, ABC): """An abstract `ProcessorStep` that specifically targets the observation in a transition.""" @abstractmethod - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def observation(self, observation: RobotObservation) -> RobotObservation: """Processes an observation dictionary. Subclasses must implement this method. 
Args: diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 93e0395b9..5cd1bebb0 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -38,7 +38,7 @@ from lerobot.utils.constants import ( ) from lerobot.utils.import_utils import _transformers_available -from .core import EnvTransition, TransitionKey +from .core import EnvTransition, RobotObservation, TransitionKey from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading @@ -139,7 +139,7 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def observation(self, observation: RobotObservation) -> RobotObservation: """ Tokenizes the task description and adds it to the observation dictionary. diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 604adb931..3d58ae18f 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -38,13 +38,12 @@ from lerobot.processor import ( GripperPenaltyProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, - JointVelocityProcessorStep, MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep, - MotorCurrentProcessorStep, Numpy2TorchActionProcessorStep, RewardClassifierProcessorStep, RobotActionToPolicyActionProcessorStep, + RobotObservation, TimeLimitProcessorStep, Torch2NumpyActionProcessorStep, TransitionKey, @@ -77,6 +76,8 @@ from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say +from .joint_observations_processor import JointVelocityProcessorStep, MotorCurrentProcessorStep + logging.basicConfig(level=logging.INFO) @@ -163,7 +164,7 @@ class RobotEnv(gym.Env): self._setup_spaces() - 
def _get_observation(self) -> dict[str, Any]: + def _get_observation(self) -> RobotObservation: """Get current robot observation including joint positions and camera images.""" obs_dict = self.robot.get_observation() raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names} @@ -220,7 +221,7 @@ class RobotEnv(gym.Env): def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None - ) -> tuple[dict[str, Any], dict[str, Any]]: + ) -> tuple[RobotObservation, dict[str, Any]]: """Reset environment to initial state. Args: @@ -249,7 +250,7 @@ class RobotEnv(gym.Env): self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names} return obs, {TeleopEvents.IS_INTERVENTION: False} - def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: + def step(self, action) -> tuple[RobotObservation, float, bool, bool, dict[str, Any]]: """Execute one environment step with given action.""" joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())} diff --git a/src/lerobot/processor/joint_observations_processor.py b/src/lerobot/rl/joint_observations_processor.py similarity index 100% rename from src/lerobot/processor/joint_observations_processor.py rename to src/lerobot/rl/joint_observations_processor.py diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index fa81e7d09..09f849772 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -16,8 +16,8 @@ import logging from functools import cached_property -from typing import Any +from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig from ..robot import Robot @@ -116,7 +116,7 @@ class BiSOFollower(Robot): self.left_arm.setup_motors() self.right_arm.setup_motors() - 
def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: obs_dict = {} # Add "left_" prefix @@ -129,7 +129,7 @@ class BiSOFollower(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: # Remove "left_" prefix left_action = { key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_") diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index 48cb09215..784a95577 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -18,12 +18,12 @@ import base64 import logging from functools import cached_property -from typing import Any import cv2 import numpy as np import requests +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -197,11 +197,11 @@ class EarthRoverMiniPlus(Robot): ACTION_ANGULAR_VEL: float, } - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: """Get current robot observation from SDK. Returns: - dict: Observation containing: + RobotObservation: Observation containing: - front: Front camera image (480, 640, 3) in RGB format - rear: Rear camera image (480, 640, 3) in RGB format - linear.vel: Current speed (0-1, SDK reports only positive speeds) @@ -255,7 +255,7 @@ class EarthRoverMiniPlus(Robot): return observation - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Send action to robot via SDK. 
Args: @@ -264,8 +264,7 @@ class EarthRoverMiniPlus(Robot): - angular.vel: Target angular velocity (-1 to 1) Returns: - dict: The action that was sent (matches action_features keys) - + RobotAction: The action that was sent (matches action_features keys) Raises: DeviceNotConnectedError: If robot is not connected diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 220a29f8c..4be8a0b17 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorNormMode @@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -128,7 +128,7 @@ class HopeJrArm(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -149,7 +149,7 @@ class HopeJrArm(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 9e960642b..73fb4464f 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from 
lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorNormMode @@ -25,6 +24,7 @@ from lerobot.motors.calibration_gui import RangeFinderGUI from lerobot.motors.feetech import ( FeetechMotorsBus, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -159,7 +159,7 @@ class HopeJrHand(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -181,7 +181,7 @@ class HopeJrHand(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index 41a57828b..a1d001ba8 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -25,6 +24,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -182,7 +182,7 @@ class KochFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if 
not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -202,7 +202,7 @@ class KochFollower(Robot): return obs_dict - def send_action(self, action: dict[str, float]) -> dict[str, float]: + def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter @@ -210,10 +210,10 @@ class KochFollower(Robot): Thus, this function always returns the action actually sent. Args: - action (dict[str, float]): The goal positions for the motors. + action (RobotAction): The goal positions for the motors. Returns: - dict[str, float]: The action sent to the motors, potentially clipped. + RobotAction: The action sent to the motors, potentially clipped. """ if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 86fe017d6..c84e81001 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -28,6 +28,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -338,7 +339,7 @@ class LeKiwi(Robot): "theta.vel": theta, } # m/s and deg/s - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -369,7 +370,7 @@ class LeKiwi(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. 
The relative action magnitude may be clipped depending on the configuration parameter @@ -380,7 +381,7 @@ class LeKiwi(Robot): RobotDeviceNotConnectedError: if robot is not connected. Returns: - np.ndarray: the action sent to the motors, potentially clipped. + RobotAction: the action sent to the motors, potentially clipped. """ if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index 19744e244..bb865dc10 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -18,11 +18,11 @@ import base64 import json import logging from functools import cached_property -from typing import Any import cv2 import numpy as np +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STATE from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @@ -172,7 +172,7 @@ class LeKiwiClient(Robot): return last_msg - def _parse_observation_json(self, obs_string: str) -> dict[str, Any] | None: + def _parse_observation_json(self, obs_string: str) -> RobotObservation | None: """Parses the JSON observation string.""" try: return json.loads(obs_string) @@ -196,15 +196,15 @@ class LeKiwiClient(Robot): return None def _remote_state_from_obs( - self, observation: dict[str, Any] - ) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + self, observation: RobotObservation + ) -> tuple[dict[str, np.ndarray], RobotObservation]: """Extracts frames, and state from the parsed observation.""" flat_state = {key: observation.get(key, 0.0) for key in self._state_order} state_vec = np.array([flat_state[key] for key in self._state_order], dtype=np.float32) - obs_dict: dict[str, Any] = {**flat_state, OBS_STATE: state_vec} + obs_dict: RobotObservation = {**flat_state, OBS_STATE: state_vec} # Decode images current_frames: dict[str, np.ndarray] = {} @@ -217,7 
+217,7 @@ class LeKiwiClient(Robot): return current_frames, obs_dict - def _get_data(self) -> tuple[dict[str, np.ndarray], dict[str, Any], dict[str, Any]]: + def _get_data(self) -> tuple[dict[str, np.ndarray], RobotObservation]: """ Polls the video socket for the latest observation data. @@ -252,7 +252,7 @@ class LeKiwiClient(Robot): return new_frames, new_state - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: """ Capture observations from the remote robot: current follower arm positions, present wheel speeds (converted to body-frame velocities: x, y, theta), @@ -307,12 +307,11 @@ class LeKiwiClient(Robot): def configure(self): pass - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ Args: - action (np.ndarray): array containing the goal positions for the motors. - + action (RobotAction): array containing the goal positions for the motors. Raises: RobotDeviceNotConnectedError: if robot is not connected. 
diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index 2dd851377..14668b3a7 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -17,7 +17,6 @@ import logging import time from functools import cached_property -from typing import Any from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -26,6 +25,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -165,7 +165,7 @@ class OmxFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -185,7 +185,7 @@ class OmxFollower(Robot): return obs_dict - def send_action(self, action: dict[str, float]) -> dict[str, float]: + def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter @@ -193,10 +193,10 @@ class OmxFollower(Robot): Thus, this function always returns the action actually sent. Args: - action (dict[str, float]): The goal positions for the motors. + action (RobotAction): The goal positions for the motors. Returns: - dict[str, float]: The action sent to the motors, potentially clipped. + RobotAction: The action sent to the motors, potentially clipped. 
""" if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 74742ee8d..6f4eef56c 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -18,9 +18,8 @@ from __future__ import annotations import time from typing import TYPE_CHECKING, Any -import numpy as np - from lerobot.cameras.utils import make_cameras_from_configs +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.import_utils import _reachy2_sdk_available from ..robot import Robot @@ -171,8 +170,8 @@ class Reachy2Robot(Robot): else: return {} - def get_observation(self) -> dict[str, np.ndarray]: - obs_dict: dict[str, Any] = {} + def get_observation(self) -> RobotObservation: + obs_dict: RobotObservation = {} # Read Reachy 2 state before_read_t = time.perf_counter() @@ -185,7 +184,7 @@ class Reachy2Robot(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if self.reachy is not None: if not self.is_connected: raise ConnectionError() diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index 5e88b915b..d1021daf4 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -15,11 +15,11 @@ import abc import builtins from pathlib import Path -from typing import Any import draccus from lerobot.motors import MotorCalibration +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, ROBOTS from .config import RobotConfig @@ -153,28 +153,28 @@ class Robot(abc.ABC): pass @abc.abstractmethod - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: """ Retrieve the current observation from the robot. 
Returns: - dict[str, Any]: A flat dictionary representing the robot's current sensory state. Its structure + RobotObservation: A flat dictionary representing the robot's current sensory state. Its structure should match :pymeth:`observation_features`. """ pass @abc.abstractmethod - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """ Send an action command to the robot. Args: - action (dict[str, Any]): Dictionary representing the desired action. Its structure should match + action (RobotAction): Dictionary representing the desired action. Its structure should match :pymeth:`action_features`. Returns: - dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by + RobotAction: The action actually sent to the motors potentially clipped or modified, e.g. by safety limits on velocity. """ pass diff --git a/src/lerobot/robots/so_follower/robot_kinematic_processor.py b/src/lerobot/robots/so_follower/robot_kinematic_processor.py index 87e832db6..2aa60e12a 100644 --- a/src/lerobot/robots/so_follower/robot_kinematic_processor.py +++ b/src/lerobot/robots/so_follower/robot_kinematic_processor.py @@ -28,6 +28,7 @@ from lerobot.processor import ( ProcessorStepRegistry, RobotAction, RobotActionProcessorStep, + RobotObservation, TransitionKey, ) from lerobot.utils.rotation import Rotation @@ -438,7 +439,7 @@ class ForwardKinematicsJointsToEEObservation(ObservationProcessorStep): kinematics: RobotKinematics motor_names: list[str] - def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + def observation(self, observation: RobotObservation) -> RobotObservation: return compute_forward_kinematics_joints_to_ee(observation, self.kinematics, self.motor_names) def transform_features( diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index 5e99b33a1..011a0061e 100644 --- 
a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -17,7 +17,7 @@ import logging import time from functools import cached_property -from typing import Any, TypeAlias +from typing import TypeAlias from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode @@ -25,6 +25,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) +from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..robot import Robot @@ -175,7 +176,7 @@ class SOFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -195,7 +196,7 @@ class SOFollower(Robot): return obs_dict - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. The relative action magnitude may be clipped depending on the configuration parameter @@ -206,7 +207,7 @@ class SOFollower(Robot): RobotDeviceNotConnectedError: if robot is not connected. Returns: - the action sent to the motors, potentially clipped. + RobotAction: the action sent to the motors, potentially clipped. 
""" if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 1764f31b5..fa6e0da85 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -26,6 +26,7 @@ import numpy as np from lerobot.cameras.utils import make_cameras_from_configs from lerobot.envs.factory import make_env +from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex from ..robot import Robot @@ -269,7 +270,7 @@ class UnitreeG1(Robot): for cam in self._cameras.values(): cam.disconnect() - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: lowstate = self._lowstate if lowstate is None: return {} @@ -350,7 +351,7 @@ class UnitreeG1(Robot): def observation_features(self) -> dict[str, type | tuple]: return {**self._motors_ft, **self._cameras_ft} - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: for motor in G1_29_JointIndex: key = f"{motor.name}.q" if key in action: diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 7b45e88e1..e32b80404 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -177,9 +177,9 @@ def rollout( action = policy.select_action(observation) action = postprocessor(action) - action_transition = {"action": action} + action_transition = {ACTION: action} action_transition = env_postprocessor(action_transition) - action = action_transition["action"] + action = action_transition[ACTION] # Convert to CPU / numpy. 
action_numpy: np.ndarray = action.to("cpu").numpy() diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 05d4534d4..18d8863d6 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -165,7 +165,7 @@ def teleop_loop( # Process action for robot through pipeline robot_action_to_send = robot_action_processor((teleop_action, obs)) - # Send processed action to robot (robot_action_processor.to_output should return dict[str, Any]) + # Send processed action to robot (robot_action_processor.to_output should return RobotAction) _ = robot.send_action(robot_action_to_send) if display_data: diff --git a/src/lerobot/scripts/lerobot_train_tokenizer.py b/src/lerobot/scripts/lerobot_train_tokenizer.py index 296447bad..1d8f4644b 100644 --- a/src/lerobot/scripts/lerobot_train_tokenizer.py +++ b/src/lerobot/scripts/lerobot_train_tokenizer.py @@ -63,6 +63,7 @@ else: from lerobot.configs import parser from lerobot.configs.types import NormalizationMode from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import ACTION, OBS_STATE @dataclass @@ -86,7 +87,7 @@ class TokenizerTrainingConfig: # Whether to apply delta transform (relative actions vs absolute actions) use_delta_transform: bool = False # Dataset key for state observations (default: "observation.state") - state_key: str = "observation.state" + state_key: str = OBS_STATE # Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) normalization_mode: str = "QUANTILES" # FAST vocabulary size (BPE vocab size) @@ -223,12 +224,10 @@ def process_episode(args): else: # if no state key, use zeros (no delta transform) state = np.zeros_like( - frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]) + frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION]) ) - action = ( - frame["action"].numpy() if 
torch.is_tensor(frame["action"]) else np.array(frame["action"]) - ) + action = frame[ACTION].numpy() if torch.is_tensor(frame[ACTION]) else np.array(frame[ACTION]) states.append(state) actions.append(action) @@ -468,8 +467,8 @@ def train_tokenizer(cfg: TokenizerTrainingConfig): # get normalization stats from dataset norm_stats = dataset.meta.stats - if norm_stats is not None and "action" in norm_stats: - action_stats = norm_stats["action"] + if norm_stats is not None and ACTION in norm_stats: + action_stats = norm_stats[ACTION] # build encoded dimension indices encoded_dim_indices = [] diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index c7072f4a7..4dbb49c1d 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -20,6 +20,8 @@ from typing import Any import numpy as np +from lerobot.processor import RobotAction + from ..teleoperator import Teleoperator from ..utils import TeleopEvents from .configuration_gamepad import GamepadTeleopConfig @@ -83,7 +85,7 @@ class GamepadTeleop(Teleoperator): self.gamepad = Gamepad() self.gamepad.start() - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: # Update the controller to get fresh inputs self.gamepad.update() diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index ec8ea18f4..55c158da8 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -21,6 +21,7 @@ import time from queue import Queue from typing import Any +from lerobot.processor import RobotAction from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from ..teleoperator import Teleoperator @@ -124,7 +125,7 @@ class KeyboardTeleop(Teleoperator): def configure(self): pass - def get_action(self) -> dict[str, Any]: + def 
get_action(self) -> RobotAction: before_read_t = time.perf_counter() if not self.is_connected: @@ -181,7 +182,7 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, } - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError( "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." @@ -374,12 +375,12 @@ class KeyboardRoverTeleop(KeyboardTeleop): # Only remove key if it's being released self.current_pressed.pop(key_char, None) - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: """ Get the current action based on pressed keys. Returns: - dict with 'linear.vel' and 'angular.vel' keys + RobotAction with 'linear.vel' and 'angular.vel' keys """ before_read_t = time.perf_counter() diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index 95020a962..cd9e3a53d 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -20,6 +20,7 @@ from typing import Any import draccus from lerobot.motors.motors_bus import MotorCalibration +from lerobot.processor import RobotAction from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS from .config import TeleoperatorConfig @@ -150,12 +151,12 @@ class Teleoperator(abc.ABC): pass @abc.abstractmethod - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: """ Retrieve the current action from the teleoperator. Returns: - dict[str, Any]: A flat dictionary representing the teleoperator's current actions. Its + RobotAction: A flat dictionary representing the teleoperator's current actions. Its structure should match :pymeth:`observation_features`. 
""" pass diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index a96e1596d..43a61b4f7 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -28,11 +28,13 @@ OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" ACTION = "action" +ACTION_PREFIX = ACTION + "." ACTION_TOKENS = ACTION + ".tokens" ACTION_TOKEN_MASK = ACTION + ".token_mask" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" +INFO = "info" ROBOTS = "robots" TELEOPERATORS = "teleoperators" diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 9143d0f66..31ca8d247 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -14,12 +14,13 @@ import numbers import os -from typing import Any import numpy as np import rerun as rr -from .constants import OBS_PREFIX, OBS_STR +from lerobot.processor import RobotAction, RobotObservation + +from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR def init_rerun( @@ -50,8 +51,8 @@ def _is_scalar(x): def log_rerun_data( - observation: dict[str, Any] | None = None, - action: dict[str, Any] | None = None, + observation: RobotObservation | None = None, + action: RobotAction | None = None, compress_images: bool = False, ) -> None: """ @@ -96,7 +97,7 @@ def log_rerun_data( for k, v in action.items(): if v is None: continue - key = k if str(k).startswith("action.") else f"action.{k}" + key = k if str(k).startswith(ACTION_PREFIX) else f"{ACTION}.{k}" if _is_scalar(v): rr.log(key, rr.Scalars(float(v))) diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index b0513fd38..d997cb6d4 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -17,10 +17,10 @@ import random from dataclasses import dataclass, field from functools import cached_property -from typing import Any from lerobot.cameras import CameraConfig, 
make_cameras_from_configs from lerobot.motors.motors_bus import Motor, MotorNormMode +from lerobot.processor import RobotAction, RobotObservation from lerobot.robots import Robot, RobotConfig from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from tests.mocks.mock_motors_bus import MockMotorsBus @@ -119,7 +119,7 @@ class MockRobot(Robot): def configure(self) -> None: pass - def get_observation(self) -> dict[str, Any]: + def get_observation(self) -> RobotObservation: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -130,7 +130,7 @@ class MockRobot(Robot): f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) } - def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + def send_action(self, action: RobotAction) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index 71b49947c..04479bad9 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -19,6 +19,7 @@ from dataclasses import dataclass from functools import cached_property from typing import Any +from lerobot.processor import RobotAction from lerobot.teleoperators import Teleoperator, TeleoperatorConfig from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError @@ -88,7 +89,7 @@ class MockTeleop(Teleoperator): def configure(self) -> None: pass - def get_action(self) -> dict[str, Any]: + def get_action(self) -> RobotAction: if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") From 1c61b43b1596178adec3d2649b170787f63f1777 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 14 Jan 2026 17:14:12 +0100 Subject: [PATCH 037/111] fix(teleop): add is_connected check to get_action (#2801) --- src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py | 4 ++++ 
src/lerobot/teleoperators/gamepad/teleop_gamepad.py | 5 +++++ src/lerobot/teleoperators/homunculus/homunculus_arm.py | 3 +++ src/lerobot/teleoperators/homunculus/homunculus_glove.py | 3 +++ src/lerobot/teleoperators/phone/teleop_phone.py | 6 ++++++ src/lerobot/teleoperators/so_leader/so_leader.py | 3 +++ 6 files changed, 24 insertions(+) diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 45c46c100..0fdc32461 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -18,6 +18,7 @@ import logging from functools import cached_property from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig +from lerobot.utils.errors import DeviceNotConnectedError from ..so_leader import SOLeader from ..teleoperator import Teleoperator @@ -92,6 +93,9 @@ class BiSOLeader(Teleoperator): self.right_arm.setup_motors() def get_action(self) -> dict[str, float]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + action_dict = {} # Add "left_" prefix diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 4dbb49c1d..14d8e3cc1 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -21,6 +21,7 @@ from typing import Any import numpy as np from lerobot.processor import RobotAction +from lerobot.utils.errors import DeviceNotConnectedError from ..teleoperator import Teleoperator from ..utils import TeleopEvents @@ -86,6 +87,9 @@ class GamepadTeleop(Teleoperator): self.gamepad.start() def get_action(self) -> RobotAction: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + # Update the controller to get fresh inputs self.gamepad.update() @@ -158,6 +162,7 @@ class GamepadTeleop(Teleoperator): self.gamepad.stop() self.gamepad = 
None + @property def is_connected(self) -> bool: """Check if gamepad is connected.""" return self.gamepad is not None diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 43116f5c0..53b4dfe68 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -300,6 +300,9 @@ class HomunculusArm(Teleoperator): logger.debug(f"Error reading frame in background thread for {self}: {e}") def get_action(self) -> dict[str, float]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + joint_positions = self._read() return {f"{joint}.pos": pos for joint, pos in joint_positions.items()} diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index fefeec1e8..7feb926ef 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -326,6 +326,9 @@ class HomunculusGlove(Teleoperator): logger.debug(f"Error reading frame in background thread for {self}: {e}") def get_action(self) -> dict[str, float]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + joint_positions = self._read() return homunculus_glove_to_hope_jr_hand( {f"{joint}.pos": pos for joint, pos in joint_positions.items()} diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py index 91e613190..402a88d43 100644 --- a/src/lerobot/teleoperators/phone/teleop_phone.py +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -165,6 +165,9 @@ class IOSPhone(BasePhone, Teleoperator): return True, pos, rot, pose def get_action(self) -> dict: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose() if not 
has_pose or not self.is_calibrated: return {} @@ -319,6 +322,9 @@ class AndroidPhone(BasePhone, Teleoperator): self._latest_message = message def get_action(self) -> dict: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + ok, raw_pos, raw_rot, pose = self._read_current_pose() if not ok or not self.is_calibrated: return {} diff --git a/src/lerobot/teleoperators/so_leader/so_leader.py b/src/lerobot/teleoperators/so_leader/so_leader.py index 760ef2eb1..d09262c74 100644 --- a/src/lerobot/teleoperators/so_leader/so_leader.py +++ b/src/lerobot/teleoperators/so_leader/so_leader.py @@ -140,6 +140,9 @@ class SOLeader(Teleoperator): print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") def get_action(self) -> dict[str, float]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + start = time.perf_counter() action = self.bus.sync_read("Present_Position") action = {f"{motor}.pos": val for motor, val in action.items()} From a17df523e03c8a6c0cb6b82765263552795257d6 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 14 Jan 2026 17:17:56 +0100 Subject: [PATCH 038/111] chore(ci): merge annoying section in PR template (#2802) * chore(ci): merge annoying section in PR template * pre-commit --- .github/PULL_REQUEST_TEMPLATE.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index ec5ac4372..43e2442d3 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -22,20 +22,21 @@ Short, imperative summary (e.g., "fix(robots): handle None in sensor parser"). S - Short, concrete bullets of the modifications (files/behaviour). - Short note if this introduces breaking changes and migration steps. -## How was this tested +## How was this tested (or how to run locally) - Tests added: list new tests or test files. - Manual checks / dataset runs performed. 
+- Instructions for the reviewer -## How to run locally (reviewer) +Example: -- Run the relevant tests: +- Ran the relevant tests: ```bash pytest -q tests/ -k ``` -- Run a quick example or CLI (if applicable): +- Reproduce with a quick example or CLI (if applicable): ```bash lerobot-train --some.option=true From 6797ce615e9f142c989f3898b01a651e6d1b38b1 Mon Sep 17 00:00:00 2001 From: Neko Date: Thu, 15 Jan 2026 17:51:42 +0800 Subject: [PATCH 039/111] chore(deps): bump `wandb` & `protobuf` (#2800) --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 067cb1df8..fa4b22bdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ dependencies = [ "packaging>=24.2,<26.0", "pynput>=1.7.7,<1.9.0", "pyserial>=3.5,<4.0", - "wandb>=0.20.0,<0.22.0", # TODO: Bumb dependency (compatible with protobuf) + "wandb>=0.24.0,<0.25.0", "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency @@ -97,7 +97,7 @@ dependencies = [ pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] transformers-dep = ["transformers>=4.57.1,<5.0.0"] -grpcio-dep = ["grpcio==1.73.1", "protobuf==6.31.0"] # TODO: Bumb dependency (compatible with wandb) +grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] From 5de38813d9b158cad749fcbbf30f46a08b59963c Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Thu, 15 Jan 2026 18:31:17 +0100 Subject: [PATCH 040/111] Add small context to the envHub doc page (#2807) * Add small context to the envHub doc page * Add the cfg: EnvConfig on the main function explanation.
--- docs/source/envhub.mdx | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/docs/source/envhub.mdx b/docs/source/envhub.mdx index ba6464460..df103d0dd 100644 --- a/docs/source/envhub.mdx +++ b/docs/source/envhub.mdx @@ -2,14 +2,32 @@ The **EnvHub** feature allows you to load simulation environments directly from the Hugging Face Hub with a single line of code. This unlocks a powerful new model for collaboration: instead of environments being locked away inside monolithic libraries, anyone can publish custom environments and share them with the community. -## Overview +## What is EnvHub? -With EnvHub, you can: +EnvHub lets you create custom robotics simulation environments with your own robot models and scenarios, and make them easily usable by anyone through the LeRobot framework. -- Load environments from the Hub instantly -- Share your custom simulation tasks with the community -- Version control your environments using Git -- Distribute complex physics simulations without packaging hassles +EnvHub packages are stored on the Hugging Face Hub, and can be seamlessly pulled and used in your AI robotics projects through LeRobot with a single line of code. + +Thanks to EnvHub, you can: + +1. **Create and publish environments** to the Hugging Face Hub as Git repositories, and distribute complex physics simulations without packaging hassles +2. **Load environments** dynamically, without installing them as packages +3. **Version and track** environment changes using Git semantics +4. **Discover** new simulation tasks shared by the community + +This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, or create your own custom robot and environment without worrying about dependency conflicts or complex installation procedures. 
+ +When you create an EnvHub package, you can build anything you want inside it and use any simulation tool you like: this is your own space to play with. The only requirement is that the package contains an `env.py` file that defines the environment and allows LeRobot to load and use your EnvHub package. + +This `env.py` file needs to expose a small API so LeRobot can load and run it. In particular, you must provide a `make_env(n_envs: int = 1, use_async_envs: bool = False)` or `make_env(n_envs: int = 1, use_async_envs: bool = False, cfg: EnvConfig | None = None)` function, which is the main entry point for LeRobot. It should return one of: + +- A `gym.vector.VectorEnv` (most common) +- A single `gym.Env` (will be automatically wrapped) +- A dict mapping `{suite_name: {task_id: VectorEnv}}` (for multi-task benchmarks) + +You can also pass an `EnvConfig` object to `make_env` to configure the environment (e.g. the number of environments, task, camera name, initial states, control mode, episode length, etc.). + +Finally, your environment must implement the standard `gym.vector.VectorEnv` interface so it works with LeRobot, including methods like `reset` and `step`. ## Quick Start @@ -29,17 +47,6 @@ env = make_env("lerobot/cartpole-env", trust_remote_code=True) hash for reproducibility and security. -## What is EnvHub? - -EnvHub is a framework that allows researchers and developers to: - -1. **Publish environments** to the Hugging Face Hub as Git repositories -2. **Load environments** dynamically without installing them as packages -3. **Version and track** environment changes using Git semantics -4. **Discover** new simulation tasks shared by the community - -This design means you can go from discovering an interesting environment on the Hub to running experiments in seconds, without worrying about dependency conflicts or complex installation procedures.
- ## Repository Structure To make your environment loadable from the Hub, your repository must contain at minimum: From 76d6b71b0ab9fc4f5aaec7763a2fbdfafd7711eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=2E/c=C2=B2?= Date: Thu, 15 Jan 2026 20:39:58 -0500 Subject: [PATCH 041/111] Correct Frodobots Earth Rover SDK link and add setup instructions (#2671) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix SDK link and enhance setup instructions Updated the Frodobots SDK link and added credential setup instructions. Signed-off-by: ./c² * Update docs/source/earthrover_mini_plus.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: ./c² * Update docs/source/earthrover_mini_plus.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Steven Palma --------- Signed-off-by: ./c² Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- docs/source/earthrover_mini_plus.mdx | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index 7e27eb93e..e3ffa6b32 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -12,23 +12,42 @@ The EarthRover Mini Plus is a fully open source mobile robot that connects throu ### Setting Up the Frodobots SDK -The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how: +The robot needs the [Frodobots SDK](https://github.com/frodobots-org/earth-rovers-sdk) running on your computer. Here's how: 1. Download and install the SDK: ```bash -git clone https://github.com/Frodobots/earth-rovers-sdk.git +git clone https://github.com/frodobots-org/earth-rovers-sdk.git cd earth-rovers-sdk pip install -r requirements.txt ``` -2. Start the SDK: +2. 
Save Credentials: + +Write your .env variables with the SDK API key and bot name provided by the Frodobots team. + +```bash +SDK_API_TOKEN=your_sdk_api_token_here +BOT_SLUG=your_bot_slug_here +CHROME_EXECUTABLE_PATH=/path/to/chrome_or_chromium +# Default value is MAP_ZOOM_LEVEL=18 https://wiki.openstreetmap.org/wiki/Zoom_levels +MAP_ZOOM_LEVEL=18 +MISSION_SLUG=your_mission_slug_here +# Image quality between 0.1 and 1.0 (default: 0.8) +# Recommended: 0.8 for better performance +IMAGE_QUALITY=0.8 +# Image format: jpeg, png or webp (default: png) +# Recommended: jpeg for better performance and lower bandwidth usage +IMAGE_FORMAT=jpeg +``` + +3. Start the SDK: ```bash hypercorn main:app --reload ``` -3. Open your web browser and go to `http://localhost:8000`, then click "Join" +4. Open your web browser and go to `http://localhost:8000`, then click "Join" The SDK gives you: From b825880c4064a2acb032f509bf3184fa5fba379a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 16 Jan 2026 14:38:42 +0100 Subject: [PATCH 042/111] chore: add security policy (#2809) * chore: add security policy * pre-commit style --- SECURITY.md | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..cf58f6cdb --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,48 @@ +# Security Policy + +## Project Status & Philosophy + +`lerobot` has so far been primarily a research and prototyping tool, which is why deployment security hasn’t been a strong focus until now. As `lerobot` continues to be adopted and deployed in production, we are paying much closer attention to these kinds of issues. + +Fortunately, being an open-source project, the community can also help by reporting and fixing vulnerabilities. We appreciate your efforts to responsibly disclose your findings and will make every effort to acknowledge your contributions. 
+ +## Reporting a Vulnerability + +To report a security issue, please use the GitHub Security Advisory ["Report a Vulnerability"](https://github.com/huggingface/lerobot/security/advisories/new) tab. + +The `lerobot` team will send a response indicating the next steps in handling your report. After the initial reply to your report, the security team will keep you informed of the progress towards a fix and full announcement, and may ask for additional information or guidance. + +#### Hugging Face Security Team + +Since this project is part of the Hugging Face ecosystem, feel free to submit vulnerability reports directly to: **[security@huggingface.co](mailto:security@huggingface.co)**. Someone from the HF security team will review the report and recommend next steps. + +#### Open Source Disclosures + +If reporting a vulnerability specific to the open-source codebase (and not the underlying Hub infrastructure), you may also use [Huntr](https://huntr.com), a vulnerability disclosure program for open source software. + +## Supported Versions + +Currently, we treat `lerobot` as a rolling release. We prioritize security updates for the latest available version (`main` branch). + +| Version | Supported | +| -------- | --------- | +| Latest | ✅ | +| < Latest | ❌ | + +## Secure Usage Guidelines + +`lerobot` is tightly coupled to the Hugging Face Hub for sharing data and pretrained policies. When downloading artifacts uploaded by others, you expose yourself to risks. Please read below for recommendations to keep your runtime and robot environment safe. + +### Remote Artefacts (Weights & Policies) + +Models and policies uploaded to the Hugging Face Hub come in different formats. We heavily recommend uploading and downloading models in the [`safetensors`](https://github.com/huggingface/safetensors) format. + +`safetensors` was developed specifically to prevent arbitrary code execution on your system, which is critical when running software on physical hardware/robots. 
+ +To avoid loading models from unsafe formats (e.g., `pickle`), you should ensure you are prioritizing `safetensors` files. + +### Remote Code + +Some models or environments on the Hub may require `trust_remote_code=True` to run custom architecture code. + +Please **always** verify the content of the modeling files when using this argument. We recommend setting a specific `revision` (commit hash) when loading remote code to ensure you protect yourself from unverified updates to the repository. From 112b2d173a38390f0d03abc3fed8a1b7219850e2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 16 Jan 2026 14:39:00 +0100 Subject: [PATCH 043/111] chore(ci): deactivates cron job on unbound dep tests (#2810) --- .github/workflows/unbound_deps_tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index e3ae71cc9..a75ecc121 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -20,8 +20,8 @@ on: workflow_dispatch: # Run on the 1st and 15th of every month at 09:00 UTC - schedule: - - cron: '0 2 1,15 * *' + # schedule: + # - cron: '0 2 1,15 * *' permissions: contents: read From 19dce784575fa36d853f152f87736599a5097fd9 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 16 Jan 2026 17:14:28 +0100 Subject: [PATCH 044/111] Refactor: Move PEFT config from training script to policy level (#2806) * move peft config from `lerobot_train` to policy level * Update src/lerobot/scripts/lerobot_train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Michel Aractingi * copilot response * Change the policy function to return targets rather than peft config. `_get_default_peft_targets()` override in PI0, PI0.5, SmolVLA * remove none check when building config dict --------- Signed-off-by: Michel Aractingi --- src/lerobot/policies/pi0/modeling_pi0.py | 11 ++
src/lerobot/policies/pi05/modeling_pi05.py | 11 ++ src/lerobot/policies/pretrained.py | 164 ++++++++++++++++++ .../policies/smolvla/modeling_smolvla.py | 22 +++ src/lerobot/scripts/lerobot_train.py | 90 +--------- 5 files changed, 211 insertions(+), 87 deletions(-) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 0445d6c00..58b5dc07b 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -1297,3 +1297,14 @@ class PI0Policy(PreTrainedPolicy): loss = losses.mean() loss_dict["loss"] = loss.item() return loss, loss_dict + + def _get_default_peft_targets(self) -> dict[str, any]: + """Return default PEFT target modules for PI0 fine-tuning.""" + common_projections = ( + "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + ) + target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))" + return { + "target_modules": target_modules, + "modules_to_save": [], + } diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 11d8b4d68..104ec63bf 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -1270,3 +1270,14 @@ class PI05Policy(PreTrainedPolicy): loss = losses.mean() loss_dict["loss"] = loss.item() return loss, loss_dict + + def _get_default_peft_targets(self) -> dict[str, any]: + """Return default PEFT target modules for PI0.5 fine-tuning.""" + common_projections = ( + "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + ) + target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))" + return { + "target_modules": target_modules, + "modules_to_save": [], + } diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index a1499d077..e730b78a7 100644 --- a/src/lerobot/policies/pretrained.py +++ 
b/src/lerobot/policies/pretrained.py @@ -13,6 +13,7 @@ # limitations under the License. import abc import builtins +import dataclasses import logging import os from importlib.resources import files @@ -265,3 +266,166 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): card = ModelCard.from_template(card_data, template_str=template_card) card.validate() return card + + def wrap_with_peft( + self, + peft_config=None, + peft_cli_overrides: dict | None = None, + ) -> "PreTrainedPolicy": + """ + Wrap this policy with PEFT adapters for parameter-efficient fine-tuning. + + This method is the single entry point for PEFT integration. Subclasses should + override `_get_default_peft_targets()` to provide default target modules, and + `_validate_peft_config()` for policy-specific validation. + + Args: + peft_config: Optional PEFT adapter configuration (e.g., LoraConfig). + If provided, used directly (with CLI overrides applied). + peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.) + These are merged with policy defaults to build the final config. 
+ """ + from peft import get_peft_model + + # If user provided a complete config, use it directly (with overrides) + if peft_config is not None: + final_config = peft_config + if peft_cli_overrides: + final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides) + else: + # Build config from defaults + CLI overrides + final_config = self._build_peft_config(peft_cli_overrides or {}) + + # Validate the configuration + self._validate_peft_config(final_config) + + # Freeze base parameters, only adapter params will be trained + for p in self.parameters(): + p.requires_grad_(False) + + # Store pretrained path for PEFT's base_model_name_or_path + if self.config.pretrained_path: + self.name_or_path = str(self.config.pretrained_path) + + # Wrap with PEFT + peft_model = get_peft_model(self, final_config) + + # Mark config as using PEFT for proper loading later + peft_model.config.use_peft = True + + logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})") + return peft_model + + def _get_default_peft_targets(self) -> dict[str, any] | None: + """ + Return default PEFT target modules for this policy. + + Override this in subclasses to provide policy-specific defaults. These defaults + are PEFT-method agnostic - they only specify which modules to target. + + """ + return None + + def _validate_peft_config(self, peft_config) -> None: + """ + Validate the PEFT configuration for this policy. + + Override this in subclasses to add policy-specific validation or warnings. + The default implementation checks that a pretrained_path exists. + + Args: + peft_config: The PEFT configuration to validate. + + Raises: + ValueError: If the configuration is invalid. + """ + if not self.config.pretrained_path: + raise ValueError( + "Training from scratch using PEFT is unlikely to yield good results. " + "Supply a `policy.pretrained_path` to fine-tune an existing model." 
+ ) + + def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict: + """ + Preprocess CLI overrides: rename keys and handle method-specific init_type. + + Args: + cli_overrides: Dict of CLI options (will be copied, not mutated). + peft_method_type: The PeftType enum value for the PEFT method. + + Returns: + Preprocessed dict with renamed keys and init_type mapped to method-specific key. + """ + from peft import PeftType + + cli_overrides = cli_overrides.copy() + + # Handle the full_training_modules -> modules_to_save rename + if "full_training_modules" in cli_overrides: + cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules") + + # Remove method_type as it's handled separately + cli_overrides.pop("method_type", None) + + # Handle init_type specially based on PEFT method + init_type = cli_overrides.pop("init_type", None) + if init_type is not None: + if peft_method_type == PeftType.LORA: + cli_overrides["init_lora_weights"] = init_type + elif peft_method_type == PeftType.MISS: + cli_overrides["init_weights"] = init_type + else: + raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.") + + return cli_overrides + + def _build_peft_config(self, cli_overrides: dict): + """Build a PEFT config from policy defaults and CLI overrides.""" + from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType + + # Determine PEFT method type (default to LORA) + method_type_str = cli_overrides.get("method_type") or "lora" + peft_method_type = PeftType[method_type_str.upper()] + peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type] + + # Preprocess CLI overrides + cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type) + + # Start with policy defaults, apply CLI overrides + config_dict = dict(self._get_default_peft_targets() or {}) + for key, value in cli_overrides.items(): + if value is not None: + config_dict[key] = value + + # Ensure we have target_modules + if not 
config_dict.get("target_modules"): + raise ValueError( + f"Policy '{self.name}' does not define default target_modules. " + "Please pass --peft.target_modules explicitly." + ) + + return peft_config_cls(**config_dict) + + def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict): + """Apply CLI overrides to an existing PEFT config.""" + from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType + + # Get method type from existing config or CLI override + method_type_str = cli_overrides.get("method_type") + if method_type_str: + peft_method_type = PeftType[method_type_str.upper()] + peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type] + else: + peft_method_type = PeftType(peft_config.peft_type) + peft_config_cls = type(peft_config) + + # Preprocess CLI overrides + cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type) + + # Start with existing config, apply CLI overrides + config_dict = {k: v for k, v in dataclasses.asdict(peft_config).items() if not k.startswith("_")} + for key, value in cli_overrides.items(): + if value is not None: + config_dict[key] = value + + return peft_config_cls(**config_dict) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index f998661f9..c611e9ba2 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -480,6 +480,28 @@ class SmolVLAPolicy(PreTrainedPolicy): actions = pad_vector(batch[ACTION], self.config.max_action_dim) return actions + def _get_default_peft_targets(self) -> dict[str, any]: + """Return default PEFT target modules for SmolVLA fine-tuning.""" + common_projections = ( + "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" + ) + target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))" + return { + "target_modules": target_modules, + "modules_to_save": [], + } + + def 
_validate_peft_config(self, peft_config) -> None: + """Validate PEFT configuration for SmolVLA.""" + super()._validate_peft_config(peft_config) + if not self.config.load_vlm_weights: + import logging + + logging.warning( + "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. " + "Set `load_vlm_weights=True` to fine-tune the existing policy." + ) + def pad_tensor(tensor, max_len, pad_value=0): """ diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 55737d5d8..41f866ce8 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -148,92 +148,6 @@ def update_policy( return train_metrics, output_dict -def get_default_peft_configuration(policy_type): - """Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint.""" - - common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out" - - if policy_type == "smolvla": - return { - "target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))", - "modules_to_save": [], - } - elif policy_type in ("pi0", "pi05"): - return { - "target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))", - "modules_to_save": [], - } - - return {"modules_to_save": None} - - -def wrap_policy_in_peft_model(cfg, policy): - from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model - - # Disable all gradients because we'll only train the parameters selected by the PEFT method. - # Layers that should receive gradients anyway need to be listed in `modules_to_save`. - for p in policy.parameters(): - p.requires_grad_(False) - - if not cfg.policy.pretrained_path: - raise ValueError( - "Training from scratch using PEFT. This is unlikely to yield good results. " - "Supply a `policy.path` to fine-tune an existing model." 
- ) - - if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights: - logging.warning( - "Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set " - "`load_vlm_weights=True` to fine-tune the existing policy." - ) - - peft_config_policy = get_default_peft_configuration(cfg.policy.type) - peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {} - peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT - peft_method_type = PeftType[peft_config_cli["method_type"].upper()] - peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type] - - # Handle specific CLI overrides - for key in ["target_modules", "modules_to_save", "r"]: - if peft_config_cli[key] is not None: - peft_config_policy[key] = peft_config_cli[key] - - if "target_modules" not in peft_config_policy: - raise ValueError( - f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually." - ) - - # Init method depends on the used PEFT method, your specific PEFT method - # might not be considered here, in that case an error is raised. - if peft_config_cli["init_type"] is not None: - if peft_method_type == "LORA": - peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"] - elif peft_method_type == "MISS": - peft_config_policy["init_weights"] = peft_config_cli["init_type"] - else: - raise ValueError( - f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}." - ) - - # PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the - # correct base model in `make_policy` since in a PEFT loading setting we only get the path to the - # adapter, not the base model. 
- if policy.config.pretrained_path: - policy.name_or_path = str(policy.config.pretrained_path) - - # Finally wrap the policy in a PEFT model - policy = get_peft_model( - policy, - peft_config_cls(**peft_config_policy), - ) - - # Make sure that the config is tagged as using PEFT so that the loading code can take the - # appropriate steps to use the adapter weights and the PEFT config instead of the full model weights. - policy.config.use_peft = True - - return policy - - @parser.wrap() def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): """ @@ -326,7 +240,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): if cfg.peft is not None: logging.info("Using PEFT! Wrapping model.") - policy = wrap_policy_in_peft_model(cfg, policy) + # Convert CLI peft config to dict for overrides + peft_cli_overrides = dataclasses.asdict(cfg.peft) + policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone() From 33910673ec59dda65a5a4e3b02ca7867bf4a552e Mon Sep 17 00:00:00 2001 From: Alex Tyshka Date: Fri, 16 Jan 2026 12:14:15 -0500 Subject: [PATCH 045/111] Bugfix: Add tests for image deletion and fix mixed image-video deletion (#2592) * Add tests for image deletion and fix mixed-image-video deletion * Fix docstring whitespace * Remove debug print Signed-off-by: Alex Tyshka * Remove inaccurate comment * Remove batched video test --------- Signed-off-by: Alex Tyshka --- src/lerobot/datasets/lerobot_dataset.py | 2 +- tests/datasets/test_datasets.py | 59 +++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index d9cc28b30..420117996 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1498,7 +1498,7 @@ class LeRobotDataset(torch.utils.data.Dataset): episode_index = 
self.episode_buffer["episode_index"] if isinstance(episode_index, np.ndarray): episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0] - for cam_key in self.meta.camera_keys: + for cam_key in self.meta.image_keys: img_dir = self._get_image_file_dir(episode_index, cam_key) if img_dir.is_dir(): shutil.rmtree(img_dir) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 4c91c55c0..157a14851 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -352,6 +352,65 @@ def test_image_array_to_pil_image_wrong_range_float_0_255(): image_array_to_pil_image(image) +def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed for image features after saving episode.""" + # Image feature: images should be deleted after saving episode + image_key = "image" + features_image = { + image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]} + } + ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image) + ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + ds_img.save_episode() + img_dir = ds_img._get_image_file_dir(0, image_key) + assert not img_dir.exists(), "Temporary image directory should be removed for image features" + + +def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`.""" + # Video feature: when batch_encoding_size == 1 temporary images should be deleted + vid_key = "video" + features_video = { + vid_key: {"dtype": "video", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]} + } + + ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video) + ds_vid.batch_encoding_size = 1 + ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"}) + ds_vid.save_episode() + 
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key) + assert not vid_img_dir.exists(), ( + "Temporary image directory should be removed when batch_encoding_size == 1" + ) + + +def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): + """Verify temporary image directories are removed appropriately when both image and video features are present.""" + image_key = "image" + vid_key = "video" + features_mixed = { + image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}, + vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]}, + } + ds_mixed = empty_lerobot_dataset_factory( + root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2 + ) + ds_mixed.add_frame( + { + "image": np.random.rand(*DUMMY_CHW), + "video": np.random.rand(*DUMMY_HWC), + "task": "Dummy task", + } + ) + ds_mixed.save_episode() + img_dir = ds_mixed._get_image_file_dir(0, image_key) + vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key) + assert not img_dir.exists(), "Temporary image directory should be removed for image features" + assert vid_img_dir.exists(), ( + "Temporary image directory should not be removed for video features when batch_encoding_size == 2" + ) + + # TODO(aliberts): # - [ ] test various attributes & state from init and create # - [ ] test init with episodes and check num_frames From 77dc49b3a38e9d20e1096ca95a1c47745925a604 Mon Sep 17 00:00:00 2001 From: Alex Tyshka Date: Fri, 16 Jan 2026 12:14:54 -0500 Subject: [PATCH 046/111] Fix delta timestamps with episodes filter and add tests (#2612) --- src/lerobot/datasets/lerobot_dataset.py | 23 ++- tests/datasets/test_datasets.py | 199 ++++++++++++++++++++++++ 2 files changed, 218 insertions(+), 4 deletions(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 420117996..6798e7fd7 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -935,17 
+935,30 @@ class LeRobotDataset(torch.utils.data.Dataset): else: return get_hf_features_from_features(self.features) - def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + def _get_query_indices( + self, abs_idx: int, ep_idx: int + ) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]: + """Compute query indices for delta timestamps. + + Args: + abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes). + ep_idx: The episode index. + + Returns: + A tuple of (query_indices, padding) where: + - query_indices: Dict mapping keys to lists of absolute indices to query + - padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions + """ ep = self.meta.episodes[ep_idx] ep_start = ep["dataset_from_index"] ep_end = ep["dataset_to_index"] query_indices = { - key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx] + key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx] for key, delta_idx in self.delta_indices.items() } padding = { # Pad values outside of current episode range f"{key}_is_pad": torch.BoolTensor( - [(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx] + [(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx] ) for key, delta_idx in self.delta_indices.items() } @@ -1037,10 +1050,12 @@ class LeRobotDataset(torch.utils.data.Dataset): self._ensure_hf_dataset_loaded() item = self.hf_dataset[idx] ep_idx = item["episode_index"].item() + # Use the absolute index from the dataset for delta timestamp calculations + abs_idx = item["index"].item() query_indices = None if self.delta_indices is not None: - query_indices, padding = self._get_query_indices(idx, ep_idx) + query_indices, padding = self._get_query_indices(abs_idx, ep_idx) query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} for key, val in query_result.items(): diff --git 
a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 157a14851..27c51b3c4 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1451,3 +1451,202 @@ def test_valid_video_codecs_constant(): assert "hevc" in VALID_VIDEO_CODECS assert "libsvtav1" in VALID_VIDEO_CODECS assert len(VALID_VIDEO_CODECS) == 3 + + +def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory): + """Regression test for bug where delta_timestamps incorrectly marked all frames as padded when using episodes filter. + + The bug occurred because _get_query_indices was using the relative index (idx) in the filtered dataset + instead of the absolute index when comparing against episode boundaries (ep_start, ep_end). + """ + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]}, + } + + dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False) + + # Create 3 episodes with 10 frames each + frames_per_episode = 10 + for ep_idx in range(3): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "action": torch.randn(2), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load only episode 1 (middle episode) with delta_timestamps + delta_ts = {"observation.state": [0.0]} # Just the current frame + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + ) + + # Verify the filtered dataset has the correct length + assert len(filtered_dataset) == frames_per_episode + + # Check that no frames are marked as padded (since delta=0 should always be valid) + for idx in range(len(filtered_dataset)): + frame = filtered_dataset[idx] + assert frame["observation.state_is_pad"].item() is False, 
f"Frame {idx} incorrectly marked as padded" + # Verify we're getting data from episode 1 + assert frame["episode_index"].item() == 1 + + +def test_delta_timestamps_padding_at_episode_boundaries(tmp_path, empty_lerobot_dataset_factory): + """Test that delta_timestamps correctly marks padding at episode boundaries when using episodes filter.""" + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + "action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 3 episodes with 5 frames each + frames_per_episode = 5 + for ep_idx in range(3): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "action": torch.randn(2), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load only episode 1 with delta_timestamps that go beyond episode boundaries + # fps=10, so 0.1s = 1 frame offset + delta_ts = {"observation.state": [-0.2, -0.1, 0.0, 0.1, 0.2]} # -2, -1, 0, +1, +2 frames + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + tolerance_s=0.04, # Slightly less than half a frame at 10fps + ) + + assert len(filtered_dataset) == frames_per_episode + + # Check padding at the start of the episode (first frame) + first_frame = filtered_dataset[0] + is_pad = first_frame["observation.state_is_pad"].tolist() + # At frame 0 of episode 1: delta -2 and -1 should be padded, 0, +1, +2 should not + assert is_pad == [True, True, False, False, False], f"First frame padding incorrect: {is_pad}" + + # Check middle frame (no padding expected) + mid_frame = filtered_dataset[2] + is_pad = mid_frame["observation.state_is_pad"].tolist() + assert is_pad == [False, False, False, False, False], f"Middle frame 
padding incorrect: {is_pad}" + + # Check padding at the end of the episode (last frame) + last_frame = filtered_dataset[4] + is_pad = last_frame["observation.state_is_pad"].tolist() + # At frame 4 of episode 1: delta -2, -1, 0 should not be padded, +1, +2 should be + assert is_pad == [False, False, False, True, True], f"Last frame padding incorrect: {is_pad}" + + +def test_delta_timestamps_multiple_episodes_filter(tmp_path, empty_lerobot_dataset_factory): + """Test delta_timestamps with multiple non-consecutive episodes selected.""" + features = { + "observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 5 episodes with 5 frames each + frames_per_episode = 5 + for ep_idx in range(5): + for frame_idx in range(frames_per_episode): + dataset.add_frame( + { + "observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load episodes 1 and 3 (non-consecutive) + delta_ts = {"observation.state": [0.0]} + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1, 3], + delta_timestamps=delta_ts, + ) + + assert len(filtered_dataset) == 2 * frames_per_episode + + # All frames should have valid (non-padded) data for delta=0 + for idx in range(len(filtered_dataset)): + frame = filtered_dataset[idx] + assert frame["observation.state_is_pad"].item() is False + + # Verify we're getting the correct episodes + episode_indices = [filtered_dataset[i]["episode_index"].item() for i in range(len(filtered_dataset))] + expected_episodes = [1] * frames_per_episode + [3] * frames_per_episode + assert episode_indices == expected_episodes + + +def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_dataset_factory): + """Test that delta_timestamps returns the correct observation 
values, not just correct padding.""" + features = { + "observation.state": {"dtype": "float32", "shape": (1,), "names": ["x"]}, + } + + dataset = empty_lerobot_dataset_factory( + root=tmp_path / "test", features=features, use_videos=False, fps=10 + ) + + # Create 2 episodes with known values + # Episode 0: frames with values 0, 1, 2, 3, 4 + # Episode 1: frames with values 10, 11, 12, 13, 14 + frames_per_episode = 5 + for ep_idx in range(2): + for frame_idx in range(frames_per_episode): + value = ep_idx * 10 + frame_idx + dataset.add_frame( + { + "observation.state": torch.tensor([value], dtype=torch.float32), + "task": f"task_{ep_idx}", + } + ) + dataset.save_episode() + dataset.finalize() + + # Load episode 1 with delta that looks at previous frame + delta_ts = {"observation.state": [-0.1, 0.0]} # Previous frame and current frame + filtered_dataset = LeRobotDataset( + dataset.repo_id, + root=dataset.root, + episodes=[1], + delta_timestamps=delta_ts, + tolerance_s=0.04, + ) + + # Check frame 2 of episode 1 (which has absolute index 7, value 12) + frame = filtered_dataset[2] + state_values = frame["observation.state"].tolist() + # Should get [11, 12] - the previous and current values within episode 1 + assert state_values == [11.0, 12.0], f"Expected [11.0, 12.0], got {state_values}" + + # Check first frame - previous frame should be clamped to episode start (padded) + first_frame = filtered_dataset[0] + state_values = first_frame["observation.state"].tolist() + is_pad = first_frame["observation.state_is_pad"].tolist() + # Previous frame is outside episode, so it's clamped to first frame and marked as padded + assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}" + assert is_pad == [True, False], f"Expected [True, False], got {is_pad}" From 46e19ae579f80ce66211afafd1c3c649c569131f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 16 Jan 2026 18:52:06 +0100 Subject: [PATCH 047/111] feat: is connect checks decorators (#2813) --- 
src/lerobot/motors/motors_bus.py | 32 ++++----------- .../robot_earthrover_mini_plus.py | 15 +++---- src/lerobot/robots/hope_jr/hope_jr_arm.py | 17 +++----- src/lerobot/robots/hope_jr/hope_jr_hand.py | 18 +++----- .../robots/koch_follower/koch_follower.py | 16 +++----- src/lerobot/robots/lekiwi/lekiwi.py | 17 +++----- src/lerobot/robots/lekiwi/lekiwi_client.py | 22 +++------- .../robots/omx_follower/omx_follower.py | 16 +++----- src/lerobot/robots/so_follower/so_follower.py | 16 +++----- .../bi_so_leader/bi_so_leader.py | 6 +-- .../teleoperators/gamepad/teleop_gamepad.py | 6 +-- .../homunculus/homunculus_arm.py | 14 ++----- .../homunculus/homunculus_glove.py | 14 ++----- .../teleoperators/keyboard/teleop_keyboard.py | 31 +++----------- .../teleoperators/koch_leader/koch_leader.py | 14 ++----- .../teleoperators/omx_leader/omx_leader.py | 14 ++----- .../teleoperators/phone/teleop_phone.py | 26 ++++-------- .../reachy2_teleoperator.py | 11 ++--- .../teleoperators/so_leader/so_leader.py | 14 ++----- src/lerobot/utils/decorators.py | 41 +++++++++++++++++++ tests/mocks/mock_robot.py | 22 +++------- tests/mocks/mock_teleop.py | 23 ++++------- 22 files changed, 144 insertions(+), 261 deletions(-) create mode 100644 src/lerobot/utils/decorators.py diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 17eaa8063..91bee994a 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -32,7 +32,7 @@ import serial from deepdiff import DeepDiff from tqdm import tqdm -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up NameOrID: TypeAlias = str | int @@ -411,6 +411,7 @@ class MotorsBus(abc.ABC): """bool: `True` if the underlying serial port is open.""" return self.port_handler.is_open + @check_if_already_connected def connect(self, 
handshake: bool = True) -> None: """Open the serial port and initialise communication. @@ -422,10 +423,6 @@ class MotorsBus(abc.ABC): DeviceAlreadyConnectedError: The port is already open. ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed. """ - if self.is_connected: - raise DeviceAlreadyConnectedError( - f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice." - ) self._connect(handshake) self.set_timeout() @@ -447,6 +444,7 @@ class MotorsBus(abc.ABC): def _handshake(self) -> None: pass + @check_if_not_connected def disconnect(self, disable_torque: bool = True) -> None: """Close the serial port (optionally disabling torque first). @@ -455,10 +453,6 @@ class MotorsBus(abc.ABC): closing the port. This can prevent damaging motors if they are left applying resisting torque after disconnect. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first." - ) if disable_torque: self.port_handler.clearPort() @@ -907,6 +901,7 @@ class MotorsBus(abc.ABC): """ pass + @check_if_not_connected def read( self, data_name: str, @@ -927,10 +922,6 @@ class MotorsBus(abc.ABC): Returns: Value: Raw or normalised value depending on *normalize*. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) id_ = self.motors[motor].id model = self.motors[motor].model @@ -981,6 +972,7 @@ class MotorsBus(abc.ABC): return value, comm, error + @check_if_not_connected def write( self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 ) -> None: @@ -999,10 +991,6 @@ class MotorsBus(abc.ABC): normalize (bool, optional): Enable or disable normalisation. Defaults to `True`. 
num_retry (int, optional): Retry attempts. Defaults to `0`. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) id_ = self.motors[motor].id model = self.motors[motor].model @@ -1044,6 +1032,7 @@ class MotorsBus(abc.ABC): return comm, error + @check_if_not_connected def sync_read( self, data_name: str, @@ -1063,10 +1052,6 @@ class MotorsBus(abc.ABC): Returns: dict[str, Value]: Mapping *motor name → value*. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." - ) self._assert_protocol_is_compatible("sync_read") @@ -1139,6 +1124,7 @@ class MotorsBus(abc.ABC): # for id_ in motor_ids: # value = self.sync_reader.getData(id_, address, length) + @check_if_not_connected def sync_write( self, data_name: str, @@ -1160,10 +1146,6 @@ class MotorsBus(abc.ABC): normalize (bool, optional): If `True` (default) convert values from the user range to raw units. num_retry (int, optional): Retry attempts. Defaults to `0`. """ - if not self.is_connected: - raise DeviceNotConnectedError( - f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." 
- ) ids_values = self._get_ids_values_dict(values) models = [self._id_to_model(id_) for id_ in ids_values] diff --git a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py index 784a95577..cdf6efde1 100644 --- a/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py +++ b/src/lerobot/robots/earthrover_mini_plus/robot_earthrover_mini_plus.py @@ -24,7 +24,8 @@ import numpy as np import requests from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..robot import Robot from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig @@ -99,6 +100,7 @@ class EarthRoverMiniPlus(Robot): """Check if robot is connected to SDK.""" return self._is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """Connect to robot via Frodobots SDK. @@ -109,8 +111,6 @@ class EarthRoverMiniPlus(Robot): DeviceAlreadyConnectedError: If robot is already connected DeviceNotConnectedError: If cannot connect to SDK server """ - if self._is_connected: - raise DeviceAlreadyConnectedError(f"{self.name} is already connected") # Verify SDK is running and accessible try: @@ -197,6 +197,7 @@ class EarthRoverMiniPlus(Robot): ACTION_ANGULAR_VEL: float, } + @check_if_not_connected def get_observation(self) -> RobotObservation: """Get current robot observation from SDK. @@ -223,8 +224,6 @@ class EarthRoverMiniPlus(Robot): Robot telemetry is retrieved from /data endpoint. All SDK values are normalized to appropriate ranges for dataset recording. 
""" - if not self._is_connected: - raise DeviceNotConnectedError(f"{self.name} is not connected") observation = {} @@ -255,6 +254,7 @@ class EarthRoverMiniPlus(Robot): return observation + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Send action to robot via SDK. @@ -272,8 +272,6 @@ class EarthRoverMiniPlus(Robot): Actions are sent to SDK via POST /control endpoint. SDK expects commands in range [-1, 1]. """ - if not self._is_connected: - raise DeviceNotConnectedError(f"{self.name} is not connected") # Extract action values and convert to float linear = float(action.get(ACTION_LINEAR_VEL, 0.0)) @@ -291,6 +289,7 @@ class EarthRoverMiniPlus(Robot): ACTION_ANGULAR_VEL: angular, } + @check_if_not_connected def disconnect(self) -> None: """Disconnect from robot. @@ -299,8 +298,6 @@ class EarthRoverMiniPlus(Robot): Raises: DeviceNotConnectedError: If robot is not connected """ - if not self._is_connected: - raise DeviceNotConnectedError(f"{self.name} is not connected") # Stop the robot before disconnecting try: diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 4be8a0b17..5fd9c4d1d 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -25,7 +25,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -82,13 +82,12 @@ class HopeJrArm(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can 
be safely disabled to run calibration. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect(handshake=False) if not self.is_calibrated and calibrate: @@ -128,10 +127,8 @@ class HopeJrArm(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position", self.other_motors) @@ -149,10 +146,8 @@ class HopeJrArm(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} # Cap goal position when too far away from present position. 
@@ -165,10 +160,8 @@ class HopeJrArm(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 73fb4464f..1e5c72b72 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -25,7 +25,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_hope_jr import HopeJrHandConfig @@ -118,10 +118,8 @@ class HopeJrHand(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: self.calibrate() @@ -159,10 +157,8 @@ class HopeJrHand(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - obs_dict = {} # Read hand position @@ -181,18 +177,14 @@ class HopeJrHand(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not 
connected.") - goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} self.bus.sync_write("Goal_Position", goal_pos) return action + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index a1d001ba8..fee0adba9 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -84,13 +84,12 @@ class KochFollower(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. 
""" - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect() if not self.is_calibrated and calibrate: @@ -182,10 +181,8 @@ class KochFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position") @@ -202,6 +199,7 @@ class KochFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. @@ -215,8 +213,6 @@ class KochFollower(Robot): Returns: RobotAction: The action sent to the motors, potentially clipped. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -231,10 +227,8 @@ class KochFollower(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index c84e81001..54848f49d 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -29,7 +29,7 @@ from lerobot.motors.feetech import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from 
..robot import Robot from ..utils import ensure_safe_goal_position @@ -109,10 +109,8 @@ class LeKiwi(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -339,10 +337,8 @@ class LeKiwi(Robot): "theta.vel": theta, } # m/s and deg/s + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read actuators position for arm and vel for base start = time.perf_counter() arm_pos = self.bus.sync_read("Present_Position", self.arm_motors) @@ -370,6 +366,7 @@ class LeKiwi(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. @@ -383,8 +380,6 @@ class LeKiwi(Robot): Returns: RobotAction: the action sent to the motors, potentially clipped. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")} base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")} @@ -412,10 +407,8 @@ class LeKiwi(Robot): self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5) logger.info("Base motors stopped") + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.stop_base() self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): diff --git a/src/lerobot/robots/lekiwi/lekiwi_client.py b/src/lerobot/robots/lekiwi/lekiwi_client.py index bb865dc10..1d5ea64a6 100644 --- a/src/lerobot/robots/lekiwi/lekiwi_client.py +++ b/src/lerobot/robots/lekiwi/lekiwi_client.py @@ -24,7 +24,8 @@ import numpy as np from lerobot.processor import RobotAction, RobotObservation from lerobot.utils.constants import ACTION, OBS_STATE -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..robot import Robot from .config_lekiwi import LeKiwiClientConfig @@ -112,14 +113,10 @@ class LeKiwiClient(Robot): def is_calibrated(self) -> bool: pass + @check_if_already_connected def connect(self) -> None: """Establishes ZMQ sockets with the remote mobile robot""" - if self._is_connected: - raise DeviceAlreadyConnectedError( - "LeKiwi Daemon is already connected. Do not run `robot.connect()` twice." 
- ) - zmq = self._zmq self.zmq_context = zmq.Context() self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH) @@ -252,14 +249,13 @@ class LeKiwiClient(Robot): return new_frames, new_state + @check_if_not_connected def get_observation(self) -> RobotObservation: """ Capture observations from the remote robot: current follower arm positions, present wheel speeds (converted to body-frame velocities: x, y, theta), and a camera frame. Receives over ZMQ, translate to body-frame vel """ - if not self._is_connected: - raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.") frames, obs_dict = self._get_data() @@ -307,6 +303,7 @@ class LeKiwiClient(Robot): def configure(self): pass + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ @@ -318,10 +315,6 @@ class LeKiwiClient(Robot): Returns: np.ndarray: the action sent to the motors, potentially clipped. """ - if not self._is_connected: - raise DeviceNotConnectedError( - "ManipulatorRobot is not connected. You need to run `robot.connect()`." - ) self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space @@ -332,13 +325,10 @@ class LeKiwiClient(Robot): action_sent[ACTION] = actions return action_sent + @check_if_not_connected def disconnect(self): """Cleans ZMQ comms""" - if not self._is_connected: - raise DeviceNotConnectedError( - "LeKiwi is not connected. You need to run `robot.connect()` before disconnecting." 
- ) self.zmq_observation_socket.close() self.zmq_cmd_socket.close() self.zmq_context.term() diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index 14668b3a7..a171affbd 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -26,7 +26,7 @@ from lerobot.motors.dynamixel import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -84,6 +84,7 @@ class OmxFollower(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ For OMX robots that come pre-calibrated: @@ -91,8 +92,6 @@ class OmxFollower(Robot): - This allows using pre-calibrated robots without manual calibration - If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095) """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect() if not self.is_calibrated and calibrate: @@ -165,10 +164,8 @@ class OmxFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position") @@ -185,6 +182,7 @@ class OmxFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. 
@@ -198,8 +196,6 @@ class OmxFollower(Robot): Returns: RobotAction: The action sent to the motors, potentially clipped. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -214,10 +210,8 @@ class OmxFollower(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index 011a0061e..b4d11fe3f 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -26,7 +26,7 @@ from lerobot.motors.feetech import ( OperatingMode, ) from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -85,13 +85,12 @@ class SOFollower(Robot): def is_connected(self) -> bool: return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ We assume that at connection time, arm is in a rest position, and torque can be safely disabled to run calibration. 
""" - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") self.bus.connect() if not self.is_calibrated and calibrate: @@ -176,10 +175,8 @@ class SOFollower(Robot): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Read arm position start = time.perf_counter() obs_dict = self.bus.sync_read("Present_Position") @@ -196,6 +193,7 @@ class SOFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: """Command arm to move to a target joint configuration. @@ -209,8 +207,6 @@ class SOFollower(Robot): Returns: RobotAction: the action sent to the motors, potentially clipped. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -225,10 +221,8 @@ class SOFollower(Robot): self.bus.sync_write("Goal_Position", goal_pos) return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect(self.config.disable_torque_on_disconnect) for cam in self.cameras.values(): cam.disconnect() diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 0fdc32461..90bf2a92d 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -18,7 +18,7 @@ import logging from functools import cached_property from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig -from lerobot.utils.errors import DeviceNotConnectedError +from lerobot.utils.decorators import 
check_if_not_connected from ..so_leader import SOLeader from ..teleoperator import Teleoperator @@ -92,10 +92,8 @@ class BiSOLeader(Teleoperator): self.left_arm.setup_motors() self.right_arm.setup_motors() + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - action_dict = {} # Add "left_" prefix diff --git a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py index 14d8e3cc1..69cb0f971 100644 --- a/src/lerobot/teleoperators/gamepad/teleop_gamepad.py +++ b/src/lerobot/teleoperators/gamepad/teleop_gamepad.py @@ -21,7 +21,7 @@ from typing import Any import numpy as np from lerobot.processor import RobotAction -from lerobot.utils.errors import DeviceNotConnectedError +from lerobot.utils.decorators import check_if_not_connected from ..teleoperator import Teleoperator from ..utils import TeleopEvents @@ -86,10 +86,8 @@ class GamepadTeleop(Teleoperator): self.gamepad = Gamepad() self.gamepad.start() + @check_if_not_connected def get_action(self) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - # Update the controller to get fresh inputs self.gamepad.update() diff --git a/src/lerobot/teleoperators/homunculus/homunculus_arm.py b/src/lerobot/teleoperators/homunculus/homunculus_arm.py index 53b4dfe68..178eed544 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_arm.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_arm.py @@ -22,7 +22,7 @@ from pprint import pformat import serial from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator @@ -93,10 +93,8 @@ class 
HomunculusArm(Teleoperator): with self.serial_lock: return self.serial.is_open and self.thread.is_alive() + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - if not self.serial.is_open: self.serial.open() self.thread.start() @@ -299,20 +297,16 @@ class HomunculusArm(Teleoperator): except Exception as e: logger.debug(f"Error reading frame in background thread for {self}: {e}") + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - joint_positions = self._read() return {f"{joint}.pos": pos for joint, pos in joint_positions.items()} def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - DeviceNotConnectedError(f"{self} is not connected.") - self.stop_event.set() self.thread.join(timeout=1) self.serial.close() diff --git a/src/lerobot/teleoperators/homunculus/homunculus_glove.py b/src/lerobot/teleoperators/homunculus/homunculus_glove.py index 7feb926ef..c4393d660 100644 --- a/src/lerobot/teleoperators/homunculus/homunculus_glove.py +++ b/src/lerobot/teleoperators/homunculus/homunculus_glove.py @@ -24,7 +24,7 @@ import serial from lerobot.motors import MotorCalibration from lerobot.motors.motors_bus import MotorNormMode from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.utils import enter_pressed, move_cursor_up from ..teleoperator import Teleoperator @@ -119,10 +119,8 @@ class HomunculusGlove(Teleoperator): with self.serial_lock: return self.serial.is_open and self.thread.is_alive() + 
@check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - if not self.serial.is_open: self.serial.open() self.thread.start() @@ -325,10 +323,8 @@ class HomunculusGlove(Teleoperator): except Exception as e: logger.debug(f"Error reading frame in background thread for {self}: {e}") + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - joint_positions = self._read() return homunculus_glove_to_hope_jr_hand( {f"{joint}.pos": pos for joint, pos in joint_positions.items()} @@ -337,10 +333,8 @@ class HomunculusGlove(Teleoperator): def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - DeviceNotConnectedError(f"{self} is not connected.") - self.stop_event.set() self.thread.join(timeout=1) self.serial.close() diff --git a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py index 55c158da8..919f463d3 100644 --- a/src/lerobot/teleoperators/keyboard/teleop_keyboard.py +++ b/src/lerobot/teleoperators/keyboard/teleop_keyboard.py @@ -22,7 +22,7 @@ from queue import Queue from typing import Any from lerobot.processor import RobotAction -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from ..utils import TeleopEvents @@ -86,12 +86,8 @@ class KeyboardTeleop(Teleoperator): def is_calibrated(self) -> bool: pass + @check_if_already_connected def connect(self) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError( - "Keyboard is already connected. Do not run `robot.connect()` twice." 
- ) - if PYNPUT_AVAILABLE: logging.info("pynput is available - enabling local keyboard listener.") self.listener = keyboard.Listener( @@ -125,14 +121,10 @@ class KeyboardTeleop(Teleoperator): def configure(self): pass + @check_if_not_connected def get_action(self) -> RobotAction: before_read_t = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." - ) - self._drain_pressed_keys() # Generate action based on current key states @@ -144,11 +136,8 @@ class KeyboardTeleop(Teleoperator): def send_feedback(self, feedback: dict[str, Any]) -> None: pass + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`." - ) if self.listener is not None: self.listener.stop() @@ -182,12 +171,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop): "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, } + @check_if_not_connected def get_action(self) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." - ) - self._drain_pressed_keys() delta_x = 0.0 delta_y = 0.0 @@ -375,6 +360,7 @@ class KeyboardRoverTeleop(KeyboardTeleop): # Only remove key if it's being released self.current_pressed.pop(key_char, None) + @check_if_not_connected def get_action(self) -> RobotAction: """ Get the current action based on pressed keys. @@ -384,11 +370,6 @@ class KeyboardRoverTeleop(KeyboardTeleop): """ before_read_t = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError( - "KeyboardRoverTeleop is not connected. You need to run `connect()` before `get_action()`." 
- ) - self._drain_pressed_keys() linear_velocity = 0.0 diff --git a/src/lerobot/teleoperators/koch_leader/koch_leader.py b/src/lerobot/teleoperators/koch_leader/koch_leader.py index 0409f2e57..87084b6b9 100644 --- a/src/lerobot/teleoperators/koch_leader/koch_leader.py +++ b/src/lerobot/teleoperators/koch_leader/koch_leader.py @@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_koch_leader import KochLeaderConfig @@ -69,10 +69,8 @@ class KochLeader(Teleoperator): def is_connected(self) -> bool: return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -161,10 +159,8 @@ class KochLeader(Teleoperator): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - start = time.perf_counter() action = self.bus.sync_read("Present_Position") action = {f"{motor}.pos": val for motor, val in action.items()} @@ -176,9 +172,7 @@ class KochLeader(Teleoperator): # TODO(rcadene, aliberts): Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect() logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/omx_leader/omx_leader.py b/src/lerobot/teleoperators/omx_leader/omx_leader.py index c0e49b558..4423be714 100644 --- 
a/src/lerobot/teleoperators/omx_leader/omx_leader.py +++ b/src/lerobot/teleoperators/omx_leader/omx_leader.py @@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import ( DynamixelMotorsBus, OperatingMode, ) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_omx_leader import OmxLeaderConfig @@ -68,10 +68,8 @@ class OmxLeader(Teleoperator): def is_connected(self) -> bool: return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -142,10 +140,8 @@ class OmxLeader(Teleoperator): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - start = time.perf_counter() action = self.bus.sync_read("Present_Position") action = {f"{motor}.pos": val for motor, val in action.items()} @@ -157,9 +153,7 @@ class OmxLeader(Teleoperator): # TODO(rcadene, aliberts): Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect() logger.info(f"{self} disconnected.") diff --git a/src/lerobot/teleoperators/phone/teleop_phone.py b/src/lerobot/teleoperators/phone/teleop_phone.py index 402a88d43..221ee8083 100644 --- a/src/lerobot/teleoperators/phone/teleop_phone.py +++ b/src/lerobot/teleoperators/phone/teleop_phone.py @@ -28,7 +28,7 @@ from teleop import Teleop from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS from 
lerobot.teleoperators.teleoperator import Teleoperator -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.rotation import Rotation logger = logging.getLogger(__name__) @@ -81,10 +81,8 @@ class IOSPhone(BasePhone, Teleoperator): def is_connected(self) -> bool: return self._group is not None + @check_if_already_connected def connect(self) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.") lookup = hebi.Lookup() time.sleep(2.0) @@ -164,10 +162,8 @@ class IOSPhone(BasePhone, Teleoperator): pos = ar_pos - rot.apply(self.config.camera_offset) return True, pos, rot, pose + @check_if_not_connected def get_action(self) -> dict: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose() if not has_pose or not self.is_calibrated: return {} @@ -207,10 +203,8 @@ class IOSPhone(BasePhone, Teleoperator): "phone.enabled": self._enabled, } + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._group = None @@ -230,10 +224,8 @@ class AndroidPhone(BasePhone, Teleoperator): def is_connected(self) -> bool: return self._teleop is not None + @check_if_already_connected def connect(self) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - logger.info("Starting teleop stream for Android...") self._teleop = Teleop() self._teleop.subscribe(self._android_callback) @@ -321,10 +313,8 @@ class AndroidPhone(BasePhone, Teleoperator): self._latest_pose = pose self._latest_message = message + @check_if_not_connected def get_action(self) -> dict: - if not self.is_connected: - raise 
DeviceNotConnectedError(f"{self} is not connected.") - ok, raw_pos, raw_rot, pose = self._read_current_pose() if not ok or not self.is_calibrated: return {} @@ -356,10 +346,8 @@ class AndroidPhone(BasePhone, Teleoperator): "phone.enabled": self._enabled, } + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._teleop = None if self._teleop_thread and self._teleop_thread.is_alive(): self._teleop_thread.join(timeout=1.0) diff --git a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py index 578aaa7b2..db076b20f 100644 --- a/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py +++ b/src/lerobot/teleoperators/reachy2_teleoperator/reachy2_teleoperator.py @@ -26,7 +26,8 @@ if TYPE_CHECKING or _reachy2_sdk_available: else: ReachySDK = None -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..teleoperator import Teleoperator from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig @@ -126,10 +127,8 @@ class Reachy2Teleoperator(Teleoperator): def is_connected(self) -> bool: return self.reachy.is_connected() if self.reachy is not None else False + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.reachy = ReachySDK(self.config.ip_address) if not self.is_connected: @@ -146,12 +145,10 @@ class Reachy2Teleoperator(Teleoperator): def configure(self) -> None: pass + @check_if_not_connected def get_action(self) -> dict[str, float]: start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - 
joint_action: dict[str, float] = {} vel_action: dict[str, float] = {} diff --git a/src/lerobot/teleoperators/so_leader/so_leader.py b/src/lerobot/teleoperators/so_leader/so_leader.py index d09262c74..a10e3a61f 100644 --- a/src/lerobot/teleoperators/so_leader/so_leader.py +++ b/src/lerobot/teleoperators/so_leader/so_leader.py @@ -23,7 +23,7 @@ from lerobot.motors.feetech import ( FeetechMotorsBus, OperatingMode, ) -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_so_leader import SOLeaderTeleopConfig @@ -66,10 +66,8 @@ class SOLeader(Teleoperator): def is_connected(self) -> bool: return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self.bus.connect() if not self.is_calibrated and calibrate: logger.info( @@ -139,10 +137,8 @@ class SOLeader(Teleoperator): self.bus.setup_motor(motor) print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + @check_if_not_connected def get_action(self) -> dict[str, float]: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - start = time.perf_counter() action = self.bus.sync_read("Present_Position") action = {f"{motor}.pos": val for motor, val in action.items()} @@ -154,10 +150,8 @@ class SOLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - DeviceNotConnectedError(f"{self} is not connected.") - self.bus.disconnect() logger.info(f"{self} disconnected.") diff --git a/src/lerobot/utils/decorators.py b/src/lerobot/utils/decorators.py new file mode 100644 index 000000000..8fc2f9a07 --- /dev/null +++ b/src/lerobot/utils/decorators.py @@ -0,0 +1,41 @@ 
+#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import wraps + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + + +def check_if_not_connected(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__} is not connected. Run `.connect()` first." + ) + return func(self, *args, **kwargs) + + return wrapper + + +def check_if_already_connected(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self.__class__.__name__} is already connected.") + return func(self, *args, **kwargs) + + return wrapper diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py index d997cb6d4..f69a2c02a 100644 --- a/tests/mocks/mock_robot.py +++ b/tests/mocks/mock_robot.py @@ -22,7 +22,7 @@ from lerobot.cameras import CameraConfig, make_cameras_from_configs from lerobot.motors.motors_bus import Motor, MotorNormMode from lerobot.processor import RobotAction, RobotObservation from lerobot.robots import Robot, RobotConfig -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from tests.mocks.mock_motors_bus import MockMotorsBus @@ -98,10 +98,8 @@ class 
MockRobot(Robot): def is_connected(self) -> bool: return self._is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self._is_connected = True if calibrate: self.calibrate() @@ -110,19 +108,15 @@ class MockRobot(Robot): def is_calibrated(self) -> bool: return self._is_calibrated + @check_if_not_connected def calibrate(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_calibrated = True def configure(self) -> None: pass + @check_if_not_connected def get_observation(self) -> RobotObservation: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.config.random_values: return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} else: @@ -130,14 +124,10 @@ class MockRobot(Robot): f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) } + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - return action + @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_connected = False diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py index 04479bad9..89174dadf 100644 --- a/tests/mocks/mock_teleop.py +++ b/tests/mocks/mock_teleop.py @@ -21,7 +21,7 @@ from typing import Any from lerobot.processor import RobotAction from lerobot.teleoperators import Teleoperator, TeleoperatorConfig -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected @TeleoperatorConfig.register_subclass("mock_teleop") @@ -68,10 +68,8 @@ class MockTeleop(Teleoperator): 
def is_connected(self) -> bool: return self._is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") - self._is_connected = True if calibrate: self.calibrate() @@ -80,19 +78,15 @@ class MockTeleop(Teleoperator): def is_calibrated(self) -> bool: return self._is_calibrated + @check_if_not_connected def calibrate(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_calibrated = True def configure(self) -> None: pass + @check_if_not_connected def get_action(self) -> RobotAction: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.config.random_values: return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} else: @@ -100,12 +94,9 @@ class MockTeleop(Teleoperator): f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) } - def send_feedback(self, feedback: dict[str, Any]) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") + @check_if_not_connected + def send_feedback(self, feedback: dict[str, Any]) -> None: ... 
+ @check_if_not_connected def disconnect(self) -> None: - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - self._is_connected = False From da41646073d288f01f21fdd602b2281864ed06f6 Mon Sep 17 00:00:00 2001 From: Sung-Wook Lee <60949213+sean1295@users.noreply.github.com> Date: Mon, 19 Jan 2026 07:18:52 -0500 Subject: [PATCH 048/111] fix libero reset logic for correct resetting (#2817) --- src/lerobot/envs/libero.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 74882ad18..96c5cf102 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -293,9 +293,9 @@ class LiberoEnv(gym.Env): def reset(self, seed=None, **kwargs): super().reset(seed=seed) self._env.seed(seed) - if self.init_states and self._init_states is not None: - self._env.set_init_state(self._init_states[self._init_state_id]) raw_obs = self._env.reset() + if self.init_states and self._init_states is not None: + raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. 
From fe068df711dd5d08aa04c577b40d31ec61245f22 Mon Sep 17 00:00:00 2001 From: bigmbigk Date: Mon, 19 Jan 2026 22:14:10 +0900 Subject: [PATCH 049/111] fix(train): eval env initialization on train script (#2818) * fix: eval env initialization on train script Signed-off-by: bigmbigk * fix: eval env creation condition --------- Signed-off-by: bigmbigk --- src/lerobot/scripts/lerobot_train.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 41f866ce8..93b99e245 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -225,9 +225,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): # On real-world data, no need to create an environment as evaluations are done outside train.py, # using the eval.py instead, with gym_dora environment and dora-rs. eval_env = None - if cfg.eval_freq > 0 and cfg.env is not None: - if is_main_process: - logging.info("Creating env") + if cfg.eval_freq > 0 and cfg.env is not None and is_main_process: + logging.info("Creating env") eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) if is_main_process: From 5286ef8439efe95ba3f98f3d7b3c1bdf4898c347 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 19 Jan 2026 16:43:11 +0100 Subject: [PATCH 050/111] feat(utils): extend import check util (#2820) * refactor(utils): is_package_available now differentiate between pkg name and module name * refactor(tests): update require_package decorator --- src/lerobot/utils/import_utils.py | 26 ++++++--- tests/async_inference/test_policy_server.py | 2 +- tests/rl/test_actor.py | 10 ++-- tests/rl/test_actor_learner.py | 6 +- tests/rl/test_learner_service.py | 16 +++--- tests/transport/test_transport_utils.py | 62 ++++++++++----------- tests/utils.py | 4 +- 7 files changed, 67 insertions(+), 59 deletions(-) diff --git 
a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 0206a8ac9..a499b96c7 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -21,12 +21,23 @@ from typing import Any from draccus.choice_types import ChoiceRegistry -def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: - """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py - Check if the package spec exists and grab its version to avoid importing a local directory. - **Note:** this doesn't work for all packages. +def is_package_available( + pkg_name: str, import_name: str | None = None, return_version: bool = False +) -> tuple[bool, str] | bool: """ - package_exists = importlib.util.find_spec(pkg_name) is not None + Check if the package spec exists and grab its version to avoid importing a local directory. + + Args: + pkg_name: The name of the package as installed via pip (e.g. "python-can"). + import_name: The actual name used to import the package (e.g. "can"). + Defaults to pkg_name if not provided. + return_version: Whether to return the version string. 
+ """ + if import_name is None: + import_name = pkg_name + + # Check if the module spec exists using the import name + package_exists = importlib.util.find_spec(import_name) is not None package_version = "N/A" if package_exists: try: @@ -37,7 +48,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b # Fallback method: Only for "torch" and versions containing "dev" if pkg_name == "torch": try: - package = importlib.import_module(pkg_name) + package = importlib.import_module(import_name) temp_version = getattr(package, "__version__", "N/A") # Check if the version contains "dev" if "dev" in temp_version: @@ -48,9 +59,6 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b except ImportError: # If the package can't be imported, it's not available package_exists = False - elif pkg_name == "grpc": - package = importlib.import_module(pkg_name) - package_version = getattr(package, "__version__", "N/A") else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False diff --git a/tests/async_inference/test_policy_server.py b/tests/async_inference/test_policy_server.py index 29583d4fa..c3ee37c8f 100644 --- a/tests/async_inference/test_policy_server.py +++ b/tests/async_inference/test_policy_server.py @@ -62,7 +62,7 @@ class MockPolicy: @pytest.fixture -@require_package("grpc") +@require_package("grpcio", "grpc") def policy_server(): """Fresh `PolicyServer` instance with a stubbed-out policy model.""" # Import only when the test actually runs (after decorator check) diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py index ec67f1889..54e4d2870 100644 --- a/tests/rl/test_actor.py +++ b/tests/rl/test_actor.py @@ -64,7 +64,7 @@ def close_service_stub(channel, server): server.stop(None) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_establish_learner_connection_success(): from lerobot.rl.actor import establish_learner_connection @@ 
-81,7 +81,7 @@ def test_establish_learner_connection_success(): close_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_establish_learner_connection_failure(): from lerobot.rl.actor import establish_learner_connection @@ -100,7 +100,7 @@ def test_establish_learner_connection_failure(): close_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_push_transitions_to_transport_queue(): from lerobot.rl.actor import push_transitions_to_transport_queue from lerobot.transport.utils import bytes_to_transitions @@ -135,7 +135,7 @@ def test_push_transitions_to_transport_queue(): assert_transitions_equal(deserialized_transition, transitions[i]) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_transitions_stream(): from lerobot.rl.actor import transitions_stream @@ -167,7 +167,7 @@ def test_transitions_stream(): assert streamed_data[2].data == b"transition_data_3" -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_interactions_stream(): from lerobot.rl.actor import interactions_stream diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py index 5d95dee04..e13862d82 100644 --- a/tests/rl/test_actor_learner.py +++ b/tests/rl/test_actor_learner.py @@ -88,7 +88,7 @@ def cfg(): return cfg -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_end_to_end_transitions_flow(cfg): from lerobot.rl.actor import ( @@ -150,7 +150,7 @@ def test_end_to_end_transitions_flow(cfg): assert_transitions_equal(transition, input_transitions[i]) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(10) def test_end_to_end_interactions_flow(cfg): from lerobot.rl.actor import ( @@ -223,7 +223,7 @@ def 
test_end_to_end_interactions_flow(cfg): assert received == expected -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.parametrize("data_size", ["small", "large"]) @pytest.mark.timeout(10) def test_end_to_end_parameters_flow(cfg, data_size): diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py index e0f0292be..d967388f0 100644 --- a/tests/rl/test_learner_service.py +++ b/tests/rl/test_learner_service.py @@ -39,7 +39,7 @@ def learner_service_stub(): close_learner_service_stub(channel, server) -@require_package("grpc") +@require_package("grpcio", "grpc") def create_learner_service_stub( shutdown_event: Event, parameters_queue: Queue, @@ -75,7 +75,7 @@ def create_learner_service_stub( return services_pb2_grpc.LearnerServiceStub(channel), channel, server -@require_package("grpc") +@require_package("grpcio", "grpc") def close_learner_service_stub(channel, server): channel.close() server.stop(None) @@ -91,7 +91,7 @@ def test_ready_method(learner_service_stub): assert response == services_pb2.Empty() -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_interactions(): from lerobot.transport import services_pb2 @@ -135,7 +135,7 @@ def test_send_interactions(): assert interactions == [b"123", b"4", b"5", b"678"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions(): from lerobot.transport import services_pb2 @@ -181,7 +181,7 @@ def test_send_transitions(): assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_send_transitions_empty_stream(): from lerobot.transport import services_pb2 @@ -209,7 +209,7 @@ def test_send_transitions_empty_stream(): assert transitions_queue.empty() -@require_package("grpc") 
+@require_package("grpcio", "grpc") @pytest.mark.timeout(10) # force cross-platform watchdog def test_stream_parameters(): import time @@ -267,7 +267,7 @@ def test_stream_parameters(): assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_with_shutdown(): from lerobot.transport import services_pb2 @@ -319,7 +319,7 @@ def test_stream_parameters_with_shutdown(): assert received_params == [b"param_batch_1", b"stop"] -@require_package("grpc") +@require_package("grpcio", "grpc") @pytest.mark.timeout(3) # force cross-platform watchdog def test_stream_parameters_waits_and_retries_on_empty_queue(): import threading diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py index 52825a24e..63632a8f4 100644 --- a/tests/transport/test_transport_utils.py +++ b/tests/transport/test_transport_utils.py @@ -26,7 +26,7 @@ from lerobot.utils.transition import Transition from tests.utils import require_cuda, require_package -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_empty_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -37,7 +37,7 @@ def test_bytes_buffer_size_empty_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_small_buffer(): from lerobot.transport.utils import bytes_buffer_size @@ -47,7 +47,7 @@ def test_bytes_buffer_size_small_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_buffer_size_large_buffer(): from lerobot.transport.utils import CHUNK_SIZE, bytes_buffer_size @@ -58,7 +58,7 @@ def test_bytes_buffer_size_large_buffer(): assert buffer.tell() == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_empty_data(): from lerobot.transport.utils 
import send_bytes_in_chunks, services_pb2 @@ -68,7 +68,7 @@ def test_send_bytes_in_chunks_empty_data(): assert len(chunks) == 0 -@require_package("grpc") +@require_package("grpcio", "grpc") def test_single_chunk_small_data(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -82,7 +82,7 @@ def test_single_chunk_small_data(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_not_silent_mode(): from lerobot.transport.utils import send_bytes_in_chunks, services_pb2 @@ -94,7 +94,7 @@ def test_not_silent_mode(): assert chunks[0].data == b"Some data" -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_large_data(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -111,7 +111,7 @@ def test_send_bytes_in_chunks_large_data(): assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): from lerobot.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 @@ -124,7 +124,7 @@ def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_empty_data(): from lerobot.transport.utils import receive_bytes_in_chunks @@ -138,7 +138,7 @@ def test_receive_bytes_in_chunks_empty_data(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_single_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -157,7 +157,7 @@ def test_receive_bytes_in_chunks_single_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def 
test_receive_bytes_in_chunks_single_not_end_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -175,7 +175,7 @@ def test_receive_bytes_in_chunks_single_not_end_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_chunks(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -199,7 +199,7 @@ def test_receive_bytes_in_chunks_multiple_chunks(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_multiple_messages(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -235,7 +235,7 @@ def test_receive_bytes_in_chunks_multiple_messages(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_shutdown_during_receive(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -259,7 +259,7 @@ def test_receive_bytes_in_chunks_shutdown_during_receive(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_only_begin_chunk(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -279,7 +279,7 @@ def test_receive_bytes_in_chunks_only_begin_chunk(): assert queue.empty() -@require_package("grpc") +@require_package("grpcio", "grpc") def test_receive_bytes_in_chunks_missing_begin(): from lerobot.transport.utils import receive_bytes_in_chunks, services_pb2 @@ -303,7 +303,7 @@ def test_receive_bytes_in_chunks_missing_begin(): # Tests for state_to_bytes and bytes_to_state_dict -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_empty_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -314,7 +314,7 @@ def test_state_to_bytes_empty_dict(): assert reconstructed == state_dict -@require_package("grpc") +@require_package("grpcio", 
"grpc") def test_bytes_to_state_dict_empty_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -323,7 +323,7 @@ def test_bytes_to_state_dict_empty_data(): bytes_to_state_dict(b"") -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_simple_dict(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -347,7 +347,7 @@ def test_state_to_bytes_simple_dict(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_various_dtypes(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -372,7 +372,7 @@ def test_state_to_bytes_various_dtypes(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_bytes_to_state_dict_invalid_data(): from lerobot.transport.utils import bytes_to_state_dict @@ -382,7 +382,7 @@ def test_bytes_to_state_dict_invalid_data(): @require_cuda -@require_package("grpc") +@require_package("grpcio", "grpc") def test_state_to_bytes_various_dtypes_cuda(): from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes @@ -407,7 +407,7 @@ def test_state_to_bytes_various_dtypes_cuda(): assert torch.allclose(state_dict[key], reconstructed[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_python_object_to_bytes_none(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -439,7 +439,7 @@ def test_python_object_to_bytes_none(): (1, 2, 3), ], ) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_python_object_to_bytes_simple_types(obj): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -450,7 +450,7 @@ def test_python_object_to_bytes_simple_types(obj): assert type(reconstructed) is type(obj) -@require_package("grpc") +@require_package("grpcio", "grpc") def 
test_python_object_to_bytes_with_tensors(): from lerobot.transport.utils import bytes_to_python_object, python_object_to_bytes @@ -475,7 +475,7 @@ def test_python_object_to_bytes_with_tensors(): assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_empty_list(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -487,7 +487,7 @@ def test_transitions_to_bytes_empty_list(): assert isinstance(reconstructed, list) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_single_transition(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -509,7 +509,7 @@ def test_transitions_to_bytes_single_transition(): assert_transitions_equal(transitions[0], reconstructed[0]) -@require_package("grpc") +@require_package("grpcio", "grpc") def assert_transitions_equal(t1: Transition, t2: Transition): """Helper to assert two transitions are equal.""" assert_observation_equal(t1["state"], t2["state"]) @@ -519,7 +519,7 @@ def assert_transitions_equal(t1: Transition, t2: Transition): assert_observation_equal(t1["next_state"], t2["next_state"]) -@require_package("grpc") +@require_package("grpcio", "grpc") def assert_observation_equal(o1: dict, o2: dict): """Helper to assert two observations are equal.""" assert set(o1.keys()) == set(o2.keys()) @@ -527,7 +527,7 @@ def assert_observation_equal(o1: dict, o2: dict): assert torch.allclose(o1[key], o2[key]) -@require_package("grpc") +@require_package("grpcio", "grpc") def test_transitions_to_bytes_multiple_transitions(): from lerobot.transport.utils import bytes_to_transitions, transitions_to_bytes @@ -551,7 +551,7 @@ def test_transitions_to_bytes_multiple_transitions(): assert_transitions_equal(original, reconstructed_item) -@require_package("grpc") +@require_package("grpcio", "grpc") def 
test_receive_bytes_in_chunks_unknown_state(): from lerobot.transport.utils import receive_bytes_in_chunks diff --git a/tests/utils.py b/tests/utils.py index 800b7d4b3..38841db02 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,7 +167,7 @@ def require_package_arg(func): return wrapper -def require_package(package_name): +def require_package(package_name, import_name=None): """ Decorator that skips the test if the specified package is not installed. """ @@ -175,7 +175,7 @@ def require_package(package_name): def decorator(func): @wraps(func) def wrapper(*args, **kwargs): - if not is_package_available(package_name): + if not is_package_available(pkg_name=package_name, import_name=import_name): pytest.skip(f"{package_name} not installed") return func(*args, **kwargs) From 66929c5935cba96d04e3f5d1c5baf2bd7afde783 Mon Sep 17 00:00:00 2001 From: Maximilian Ofir <164489172+maximilianofir@users.noreply.github.com> Date: Mon, 19 Jan 2026 22:13:48 +0100 Subject: [PATCH 051/111] feat: add async server-client streaming support for Groot policy (#2812) --- src/lerobot/async_inference/constants.py | 2 +- src/lerobot/policies/groot/modeling_groot.py | 131 ++++++++++++++++++- 2 files changed, 131 insertions(+), 2 deletions(-) diff --git a/src/lerobot/async_inference/constants.py b/src/lerobot/async_inference/constants.py index 081db0504..56910e67f 100644 --- a/src/lerobot/async_inference/constants.py +++ b/src/lerobot/async_inference/constants.py @@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS DEFAULT_OBS_QUEUE_TIMEOUT = 2 # All action chunking policies -SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] +SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"] # TODO: Add all other robots SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"] diff --git a/src/lerobot/policies/groot/modeling_groot.py b/src/lerobot/policies/groot/modeling_groot.py index 
fd9baa9b1..9a479b8f9 100644 --- a/src/lerobot/policies/groot/modeling_groot.py +++ b/src/lerobot/policies/groot/modeling_groot.py @@ -32,16 +32,22 @@ Notes: from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below. """ +import builtins import os from collections import deque +from pathlib import Path +from typing import TypeVar import torch from torch import Tensor +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.groot_n1 import GR00TN15 from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.constants import ACTION +from lerobot.utils.constants import ACTION, OBS_IMAGES + +T = TypeVar("T", bound="GrootPolicy") class GrootPolicy(PreTrainedPolicy): @@ -90,6 +96,129 @@ class GrootPolicy(PreTrainedPolicy): """Reset policy state when environment resets.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: GrootConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Load Groot policy from pretrained model. + + Handles two cases: + 1. Base GR00T models (e.g., 'nvidia/GR00T-N1.5-3B') - loads the raw model + 2. Fine-tuned LeRobot checkpoints - loads config and weights from safetensors + + Args: + pretrained_name_or_path: Path to the GR00T model or fine-tuned checkpoint + config: Optional GrootConfig. 
If None, loads from checkpoint or creates default + force_download: Force download even if cached + resume_download: Resume interrupted download + proxies: Proxy settings + token: HuggingFace authentication token + cache_dir: Cache directory path + local_files_only: Only use local files + revision: Specific model revision + strict: Strict state dict loading + **kwargs: Additional arguments (passed to config) + + Returns: + Initialized GrootPolicy instance with loaded model + """ + from huggingface_hub import hf_hub_download + from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE + from huggingface_hub.errors import HfHubHTTPError + + print( + "The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n" + f"Loading pretrained model from: {pretrained_name_or_path}" + ) + + model_id = str(pretrained_name_or_path) + is_finetuned_checkpoint = False + + # Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors) + try: + if os.path.isdir(model_id): + is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE)) + else: + # Try to download the safetensors file to check if it exists + try: + hf_hub_download( + repo_id=model_id, + filename=SAFETENSORS_SINGLE_FILE, + revision=revision, + cache_dir=cache_dir, + force_download=False, # Just check, don't force download + proxies=proxies, + token=token, + local_files_only=local_files_only, + ) + is_finetuned_checkpoint = True + except HfHubHTTPError: + is_finetuned_checkpoint = False + except Exception: + is_finetuned_checkpoint = False + + if is_finetuned_checkpoint: + # This is a fine-tuned LeRobot checkpoint - use parent class loading + print("Detected fine-tuned LeRobot checkpoint, loading with state dict...") + return super().from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + config=config, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + 
local_files_only=local_files_only, + revision=revision, + strict=strict, + **kwargs, + ) + + # This is a base GR00T model - load it fresh + print("Detected base GR00T model, loading from HuggingFace...") + + if config is None: + # Create default config with the pretrained path + config = GrootConfig(base_model_path=str(pretrained_name_or_path)) + + # Add minimal visual feature required for validation + # validate_features() will automatically add state and action features + # These are placeholders - actual robot features come from the preprocessor + if not config.input_features: + config.input_features = { + f"{OBS_IMAGES}.camera": PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 224, 224), # Default image size from config + ), + } + else: + # Override the base_model_path with the provided path + config.base_model_path = str(pretrained_name_or_path) + + # Pass through any additional config overrides from kwargs + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + + # Create a fresh policy instance - this will automatically load the GR00T model + # in __init__ via _create_groot_model() + policy = cls(config) + + policy.eval() + return policy + def get_optim_params(self) -> dict: return self.parameters() From b2ff21962478eea53c2db50333302fda3cb19b7f Mon Sep 17 00:00:00 2001 From: Francesco Capuano <74058581+fracapuano@users.noreply.github.com> Date: Mon, 19 Jan 2026 22:36:41 +0000 Subject: [PATCH 052/111] Fixes aggregation of image datasets (#2717) * fix: use features when aggregating image based datasets * add: test asserting for data type * add: features param to writing dataset --------- Co-authored-by: Steven Palma --- src/lerobot/datasets/aggregate.py | 29 ++++-- src/lerobot/datasets/utils.py | 13 ++- tests/datasets/test_aggregate.py | 145 ++++++++++++++++++++++++++++++ 3 files changed, 180 insertions(+), 7 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 
455caf0fe..94ffe602e 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -19,6 +19,7 @@ import logging import shutil from pathlib import Path +import datasets import pandas as pd import tqdm @@ -32,6 +33,7 @@ from lerobot.datasets.utils import ( DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, get_file_size_in_mb, + get_hf_features_from_features, get_parquet_file_size_in_mb, to_parquet_with_hf_images, update_chunk_file_indices, @@ -402,12 +404,21 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si } unique_chunk_file_ids = sorted(unique_chunk_file_ids) + contains_images = len(dst_meta.image_keys) > 0 + + # retrieve features schema for proper image typing in parquet + hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None for src_chunk_idx, src_file_idx in unique_chunk_file_ids: src_path = src_meta.root / DEFAULT_DATA_PATH.format( chunk_index=src_chunk_idx, file_index=src_file_idx ) - df = pd.read_parquet(src_path) + if contains_images: + # Use HuggingFace datasets to read source data to preserve image format + src_ds = datasets.Dataset.from_parquet(str(src_path)) + df = src_ds.to_pandas() + else: + df = pd.read_parquet(src_path) df = update_data_df(df, src_meta, dst_meta) data_idx = append_or_create_parquet_file( @@ -417,8 +428,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si data_files_size_in_mb, chunk_size, DEFAULT_DATA_PATH, - contains_images=len(dst_meta.image_keys) > 0, + contains_images=contains_images, aggr_root=dst_meta.root, + hf_features=hf_features, ) return data_idx @@ -488,6 +500,7 @@ def append_or_create_parquet_file( default_path: str, contains_images: bool = False, aggr_root: Path = None, + hf_features: datasets.Features | None = None, ): """Appends data to an existing parquet file or creates a new one based on size constraints. 
@@ -503,6 +516,7 @@ def append_or_create_parquet_file( default_path: Format string for generating file paths. contains_images: Whether the data contains images requiring special handling. aggr_root: Root path for the aggregated dataset. + hf_features: Optional HuggingFace Features schema for proper image typing. Returns: dict: Updated index dictionary with current chunk and file indices. @@ -512,7 +526,7 @@ def append_or_create_parquet_file( if not dst_path.exists(): dst_path.parent.mkdir(parents=True, exist_ok=True) if contains_images: - to_parquet_with_hf_images(df, dst_path) + to_parquet_with_hf_images(df, dst_path, features=hf_features) else: df.to_parquet(dst_path) return idx @@ -527,12 +541,17 @@ def append_or_create_parquet_file( final_df = df target_path = new_path else: - existing_df = pd.read_parquet(dst_path) + if contains_images: + # Use HuggingFace datasets to read existing data to preserve image format + existing_ds = datasets.Dataset.from_parquet(str(dst_path)) + existing_df = existing_ds.to_pandas() + else: + existing_df = pd.read_parquet(dst_path) final_df = pd.concat([existing_df, df], ignore_index=True) target_path = dst_path if contains_images: - to_parquet_with_hf_images(final_df, target_path) + to_parquet_with_hf_images(final_df, target_path, features=hf_features) else: final_df.to_parquet(target_path) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 234736a75..ed678af6e 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -1172,12 +1172,21 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features: ) -def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path) -> None: +def to_parquet_with_hf_images( + df: pandas.DataFrame, path: Path, features: datasets.Features | None = None +) -> None: """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. 
This way, it can be loaded by HF dataset and correctly formatted images are returned. + + Args: + df: DataFrame to write to parquet. + path: Path to write the parquet file. + features: Optional HuggingFace Features schema. If provided, ensures image columns + are properly typed as Image() in the parquet schema. """ # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only - datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path) + ds = datasets.Dataset.from_dict(df.to_dict(orient="list"), features=features) + ds.to_parquet(path) def item_to_torch(item: dict) -> dict: diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index b710a3a4b..031c29d60 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -16,6 +16,7 @@ from unittest.mock import patch +import datasets import torch from lerobot.datasets.aggregate import aggregate_datasets @@ -380,3 +381,147 @@ def test_video_timestamps_regression(tmp_path, lerobot_dataset_factory): for key in aggr_ds.meta.video_keys: assert key in item, f"Video key {key} missing from item {i}" assert item[key].shape[0] == 3, f"Expected 3 channels for video key {key}" + + +def assert_image_schema_preserved(aggr_ds): + """Test that HuggingFace Image feature schema is preserved in aggregated parquet files. + + This verifies the fix for a bug where image columns were written with a generic + struct schema {'bytes': Value('binary'), 'path': Value('string')} instead of + the proper Image() feature type, causing HuggingFace Hub viewer to display + raw dict objects instead of image thumbnails. 
+ """ + image_keys = aggr_ds.meta.image_keys + if not image_keys: + return + + # Check that parquet files have proper Image schema + data_dir = aggr_ds.root / "data" + parquet_files = list(data_dir.rglob("*.parquet")) + assert len(parquet_files) > 0, "No parquet files found in aggregated dataset" + + for parquet_file in parquet_files: + # Load with HuggingFace datasets to check schema + ds = datasets.Dataset.from_parquet(str(parquet_file)) + + for image_key in image_keys: + feature = ds.features.get(image_key) + assert feature is not None, f"Image key '{image_key}' not found in parquet schema" + assert isinstance(feature, datasets.Image), ( + f"Image key '{image_key}' should have Image() feature type, " + f"but got {type(feature).__name__}: {feature}. " + "This indicates image schema was not preserved during aggregation." + ) + + +def assert_image_frames_integrity(aggr_ds, ds_0, ds_1): + """Test that image frames are correctly preserved after aggregation.""" + image_keys = aggr_ds.meta.image_keys + if not image_keys: + return + + def images_equal(img1, img2): + return torch.allclose(img1, img2) + + # Test the section corresponding to the first dataset (ds_0) + for i in range(len(ds_0)): + assert aggr_ds[i]["index"] == i, ( + f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}" + ) + for key in image_keys: + assert images_equal(aggr_ds[i][key], ds_0[i][key]), ( + f"Image frames at position {i} should be equal between aggregated and ds_0" + ) + + # Test the section corresponding to the second dataset (ds_1) + for i in range(len(ds_0), len(ds_0) + len(ds_1)): + assert aggr_ds[i]["index"] == i, ( + f"Frame index at position {i} should be {i}, but got {aggr_ds[i]['index']}" + ) + for key in image_keys: + assert images_equal(aggr_ds[i][key], ds_1[i - len(ds_0)][key]), ( + f"Image frames at position {i} should be equal between aggregated and ds_1" + ) + + +def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): + """Test aggregation of 
image-based datasets preserves HuggingFace Image schema. + + This test specifically verifies that: + 1. Image-based datasets can be aggregated correctly + 2. The HuggingFace Image() feature type is preserved in parquet files + 3. Image data integrity is maintained across aggregation + 4. Images can be properly decoded after aggregation + + This catches the bug where to_parquet_with_hf_images() was not passing + the features schema, causing image columns to be written as generic + struct types instead of Image() types. + """ + ds_0_num_frames = 50 + ds_1_num_frames = 75 + ds_0_num_episodes = 2 + ds_1_num_episodes = 3 + + # Create two image-based datasets (use_videos=False) + ds_0 = lerobot_dataset_factory( + root=tmp_path / "image_0", + repo_id=f"{DUMMY_REPO_ID}_image_0", + total_episodes=ds_0_num_episodes, + total_frames=ds_0_num_frames, + use_videos=False, # Image-based dataset + ) + ds_1 = lerobot_dataset_factory( + root=tmp_path / "image_1", + repo_id=f"{DUMMY_REPO_ID}_image_1", + total_episodes=ds_1_num_episodes, + total_frames=ds_1_num_frames, + use_videos=False, # Image-based dataset + ) + + # Verify source datasets have image keys + assert len(ds_0.meta.image_keys) > 0, "ds_0 should have image keys" + assert len(ds_1.meta.image_keys) > 0, "ds_1 should have image keys" + + # Aggregate the datasets + aggregate_datasets( + repo_ids=[ds_0.repo_id, ds_1.repo_id], + roots=[ds_0.root, ds_1.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_image_aggr", + aggr_root=tmp_path / "image_aggr", + ) + + # Load the aggregated dataset + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "image_aggr") + aggr_ds = LeRobotDataset(f"{DUMMY_REPO_ID}_image_aggr", root=tmp_path / "image_aggr") + + # Verify aggregated dataset has image keys + assert 
len(aggr_ds.meta.image_keys) > 0, "Aggregated dataset should have image keys" + assert aggr_ds.meta.image_keys == ds_0.meta.image_keys, "Image keys should match source datasets" + + # Run standard aggregation assertions + expected_total_episodes = ds_0_num_episodes + ds_1_num_episodes + expected_total_frames = ds_0_num_frames + ds_1_num_frames + + assert_episode_and_frame_counts(aggr_ds, expected_total_episodes, expected_total_frames) + assert_dataset_content_integrity(aggr_ds, ds_0, ds_1) + assert_metadata_consistency(aggr_ds, ds_0, ds_1) + assert_episode_indices_updated_correctly(aggr_ds, ds_0, ds_1) + + # Image-specific assertions + assert_image_schema_preserved(aggr_ds) + assert_image_frames_integrity(aggr_ds, ds_0, ds_1) + + # Verify images can be accessed and have correct shape + sample_item = aggr_ds[0] + for image_key in aggr_ds.meta.image_keys: + img = sample_item[image_key] + assert isinstance(img, torch.Tensor), f"Image {image_key} should be a tensor" + assert img.dim() == 3, f"Image {image_key} should have 3 dimensions (C, H, W)" + assert img.shape[0] == 3, f"Image {image_key} should have 3 channels" + + assert_dataset_iteration_works(aggr_ds) From 79688a09f2e88f830fb6fdf57bbcda4b35acc36f Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Tue, 20 Jan 2026 11:04:22 +0100 Subject: [PATCH 053/111] improve(dataset-tools): image2video editing tools : Multiple episodes per video file (#2811) * improve image2video * add episodes video encoding * fix mypy failing * iterate on review * nit * remove max, and let it be optional * iterate more * update docs * fix test --------- Co-authored-by: Michel Aractingi --- docs/source/using_dataset_tools.mdx | 19 +- src/lerobot/datasets/dataset_tools.py | 562 +++++++++++++++++++- src/lerobot/scripts/lerobot_edit_dataset.py | 411 ++------------ tests/datasets/test_dataset_tools.py | 10 +- 4 files changed, 611 insertions(+), 391 deletions(-) diff --git a/docs/source/using_dataset_tools.mdx 
b/docs/source/using_dataset_tools.mdx index 29e16ea0a..9e662604e 100644 --- a/docs/source/using_dataset_tools.mdx +++ b/docs/source/using_dataset_tools.mdx @@ -95,26 +95,26 @@ Convert an image-based dataset to video format, creating a new LeRobotDataset wh # Local-only: Save to a custom output directory (no hub push) lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir /path/to/output/pusht_video # Save with new repo_id (local storage) lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video + --operation.type convert_image_to_video # Convert and push to Hugging Face Hub lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --push_to_hub true # Convert with custom video codec and quality settings lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir outputs/pusht_video \ --operation.vcodec libsvtav1 \ --operation.pix_fmt yuv420p \ @@ -124,16 +124,23 @@ lerobot-edit-dataset \ # Convert only specific episodes lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir outputs/pusht_video \ --operation.episode_indices "[0, 1, 2, 5, 10]" # Convert with multiple workers for parallel processing lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir outputs/pusht_video \ --operation.num_workers 8 + +# For memory-constrained systems, users can now specify limits: +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type convert_image_to_video \ +
--operation.max_episodes_per_batch 50 \ + --operation.max_frames_per_batch 10000 ``` **Parameters:** diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index 2fb68dca1..e2928e2a6 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -26,6 +26,7 @@ This module provides utilities for: import logging import shutil from collections.abc import Callable +from concurrent.futures import ThreadPoolExecutor, as_completed from pathlib import Path import datasets @@ -51,7 +52,8 @@ from lerobot.datasets.utils import ( write_stats, write_tasks, ) -from lerobot.utils.constants import HF_LEROBOT_HOME +from lerobot.datasets.video_utils import encode_video_frames, get_video_info +from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE def _load_episode_with_stats(src_dataset: LeRobotDataset, episode_idx: int) -> dict: @@ -1083,3 +1085,561 @@ def _copy_episodes_metadata_and_stats( else: if src_dataset.meta.stats: write_stats(src_dataset.meta.stats, dst_meta.root) + + +def _save_episode_images_for_video( + dataset: LeRobotDataset, + imgs_dir: Path, + img_key: str, + episode_index: int, + num_workers: int = 4, +) -> None: + """Save images from a specific episode and camera to disk for video encoding. 
+ + Args: + dataset: The LeRobot dataset to extract images from + imgs_dir: Directory to save images to + img_key: The image key (camera) to extract + episode_index: Index of the episode to save + num_workers: Number of threads for parallel image saving + """ + # Create directory + imgs_dir.mkdir(parents=True, exist_ok=True) + + # Get dataset without torch format for PIL image access + hf_dataset = dataset.hf_dataset.with_format(None) + + # Select only this camera's images + imgs_dataset = hf_dataset.select_columns(img_key) + + # Get episode start and end indices + from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] + to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] + + # Get all items for this episode + episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) + + # Define function to save a single image + def save_single_image(i_item_tuple): + i, item = i_item_tuple + img = item[img_key] + # Use frame-XXXXXX.png format to match encode_video_frames expectations + img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100) + return i + + # Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png) + items = list(enumerate(episode_dataset)) + + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(save_single_image, item) for item in items] + for future in as_completed(futures): + future.result() # This will raise any exceptions that occurred + + +def _save_batch_episodes_images( + dataset: LeRobotDataset, + imgs_dir: Path, + img_key: str, + episode_indices: list[int], + num_workers: int = 4, +) -> list[float]: + """Save images from multiple episodes to disk for batch video encoding. 
+ + Args: + dataset: The LeRobot dataset to extract images from + imgs_dir: Directory to save images to + img_key: The image key (camera) to extract + episode_indices: List of episode indices to save + num_workers: Number of threads for parallel image saving + + Returns: + List of episode durations in seconds + """ + imgs_dir.mkdir(parents=True, exist_ok=True) + hf_dataset = dataset.hf_dataset.with_format(None) + imgs_dataset = hf_dataset.select_columns(img_key) + + # Define function to save a single image with global frame index + # Defined once outside the loop to avoid repeated closure creation + def save_single_image(i_item_tuple, base_frame_idx, img_key_param): + i, item = i_item_tuple + img = item[img_key_param] + # Use global frame index for naming + img.save(str(imgs_dir / f"frame-{base_frame_idx + i:06d}.png"), quality=100) + return i + + episode_durations = [] + frame_idx = 0 + + for ep_idx in episode_indices: + # Get episode range + from_idx = dataset.meta.episodes["dataset_from_index"][ep_idx] + to_idx = dataset.meta.episodes["dataset_to_index"][ep_idx] + episode_length = to_idx - from_idx + episode_durations.append(episode_length / dataset.fps) + + # Get episode images + episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) + + # Save images + items = list(enumerate(episode_dataset)) + with ThreadPoolExecutor(max_workers=num_workers) as executor: + futures = [executor.submit(save_single_image, item, frame_idx, img_key) for item in items] + for future in as_completed(futures): + future.result() + + frame_idx += episode_length + + return episode_durations + + +def _iter_episode_batches( + episode_indices: list[int], + episode_lengths: dict[int, int], + size_per_frame_mb: float, + video_file_size_limit: float, + max_episodes: int | None, + max_frames: int | None, +): + """Generator that yields batches of episode indices for video encoding. 
+ + Groups episodes into batches that respect size and memory constraints: + - Stays under video file size limit + - Respects maximum episodes per batch (if specified) + - Respects maximum frames per batch (if specified) + + Args: + episode_indices: List of episode indices to batch + episode_lengths: Dictionary mapping episode index to episode length + size_per_frame_mb: Estimated size per frame in MB + video_file_size_limit: Maximum video file size in MB + max_episodes: Maximum number of episodes per batch (None = no limit) + max_frames: Maximum number of frames per batch (None = no limit) + + Yields: + List of episode indices for each batch + """ + batch_episodes = [] + estimated_size = 0.0 + total_frames = 0 + + for ep_idx in episode_indices: + ep_length = episode_lengths[ep_idx] + ep_estimated_size = ep_length * size_per_frame_mb + + # we check if adding this episode would exceed any constraint + would_exceed_size = estimated_size > 0 and estimated_size + ep_estimated_size >= video_file_size_limit + would_exceed_episodes = max_episodes is not None and len(batch_episodes) >= max_episodes + would_exceed_frames = max_frames is not None and total_frames + ep_length > max_frames + + if batch_episodes and (would_exceed_size or would_exceed_episodes or would_exceed_frames): + # yield current batch before adding this episode + yield batch_episodes + # start a new batch with current episode + batch_episodes = [ep_idx] + estimated_size = ep_estimated_size + total_frames = ep_length + else: + # add to current batch + batch_episodes.append(ep_idx) + estimated_size += ep_estimated_size + total_frames += ep_length + + # yield final batch if not empty + if batch_episodes: + yield batch_episodes + + +def _estimate_frame_size_via_calibration( + dataset: LeRobotDataset, + img_key: str, + episode_indices: list[int], + temp_dir: Path, + fps: int, + vcodec: str, + pix_fmt: str, + g: int, + crf: int, + fast_decode: int, + num_calibration_frames: int = 30, +) -> float: + """Estimate 
MB per frame by encoding a small calibration sample. + + Encodes a representative sample of frames using the exact codec parameters + to measure actual compression ratio, which is more accurate than heuristics. + + Args: + dataset: Source dataset with images. + img_key: Image key to calibrate (e.g., "observation.images.top"). + episode_indices: List of episode indices being processed. + temp_dir: Temporary directory for calibration files. + fps: Frames per second for video encoding. + vcodec: Video codec (libsvtav1, h264, hevc). + pix_fmt: Pixel format (yuv420p, etc.). + g: GOP size (group of pictures). + crf: Constant Rate Factor (quality). + fast_decode: Fast decode tuning parameter. + num_calibration_frames: Number of frames to use for calibration (default: 30). + + Returns: + Estimated size in MB per frame based on actual encoding. + """ + calibration_dir = temp_dir / "calibration" / img_key + calibration_dir.mkdir(parents=True, exist_ok=True) + + try: + # Select a representative episode (prefer middle episode if available) + calibration_ep_idx = episode_indices[len(episode_indices) // 2] + + # Get episode range + from_idx = dataset.meta.episodes["dataset_from_index"][calibration_ep_idx] + to_idx = dataset.meta.episodes["dataset_to_index"][calibration_ep_idx] + episode_length = to_idx - from_idx + + # Use up to num_calibration_frames from this episode + num_frames = min(num_calibration_frames, episode_length) + + # Get frames from dataset + hf_dataset = dataset.hf_dataset.with_format(None) + sample_indices = range(from_idx, from_idx + num_frames) + + # Save calibration frames + for i, idx in enumerate(sample_indices): + img = hf_dataset[idx][img_key] + img.save(str(calibration_dir / f"frame-{i:06d}.png"), quality=100) + + # Encode calibration video + calibration_video_path = calibration_dir / "calibration.mp4" + encode_video_frames( + imgs_dir=calibration_dir, + video_path=calibration_video_path, + fps=fps, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, 
+ fast_decode=fast_decode, + overwrite=True, + ) + + # Measure actual compressed size + video_size_bytes = calibration_video_path.stat().st_size + video_size_mb = video_size_bytes / BYTES_PER_MIB + size_per_frame_mb = video_size_mb / num_frames + + logging.info( + f" Calibration: {num_frames} frames -> {video_size_mb:.2f} MB " + f"= {size_per_frame_mb:.4f} MB/frame for {img_key}" + ) + + return size_per_frame_mb + + finally: + # Clean up calibration files + if calibration_dir.exists(): + shutil.rmtree(calibration_dir) + + +def _copy_data_without_images( + src_dataset: LeRobotDataset, + dst_meta: LeRobotDatasetMetadata, + episode_indices: list[int], + img_keys: list[str], +) -> None: + """Copy data files without image columns. + + Args: + src_dataset: Source dataset + dst_meta: Destination metadata + episode_indices: Episodes to include + img_keys: Image keys to remove + """ + from lerobot.datasets.utils import DATA_DIR + + data_dir = src_dataset.root / DATA_DIR + parquet_files = sorted(data_dir.glob("*/*.parquet")) + + if not parquet_files: + raise ValueError(f"No parquet files found in {data_dir}") + + episode_set = set(episode_indices) + + for src_path in tqdm(parquet_files, desc="Processing data files"): + df = pd.read_parquet(src_path).reset_index(drop=True) + + # Filter to only include selected episodes + df = df[df["episode_index"].isin(episode_set)].copy() + + if len(df) == 0: + continue + + # Remove image columns + columns_to_drop = [col for col in img_keys if col in df.columns] + if columns_to_drop: + df = df.drop(columns=columns_to_drop) + + # Get chunk and file indices from path + relative_path = src_path.relative_to(src_dataset.root) + chunk_dir = relative_path.parts[1] + file_name = relative_path.parts[2] + chunk_idx = int(chunk_dir.split("-")[1]) + file_idx = int(file_name.split("-")[1].split(".")[0]) + + # Write to destination without pandas index + dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet" + 
dst_path.parent.mkdir(parents=True, exist_ok=True) + df.to_parquet(dst_path, index=False) + + +# Video conversion constants +BYTES_PER_KIB = 1024 +BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB + + +def convert_image_to_video_dataset( + dataset: LeRobotDataset, + output_dir: Path, + repo_id: str | None = None, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", + g: int = 2, + crf: int = 30, + fast_decode: int = 0, + episode_indices: list[int] | None = None, + num_workers: int = 4, + max_episodes_per_batch: int | None = None, + max_frames_per_batch: int | None = None, +) -> LeRobotDataset: + """Convert image-to-video dataset. + + Creates a new LeRobotDataset with images encoded as videos, following the proper + LeRobot dataset structure with videos stored in chunked MP4 files. + + Args: + dataset: The source LeRobot dataset with images + output_dir: Directory to save the new video dataset + repo_id: Repository ID for the new dataset (default: original_id + "_video") + vcodec: Video codec (default: libsvtav1) + pix_fmt: Pixel format (default: yuv420p) + g: Group of pictures size (default: 2) + crf: Constant rate factor (default: 30) + fast_decode: Fast decode tuning (default: 0) + episode_indices: List of episode indices to convert (None = all episodes) + num_workers: Number of threads for parallel processing (default: 4) + max_episodes_per_batch: Maximum episodes per video batch to avoid memory issues (None = no limit) + max_frames_per_batch: Maximum frames per video batch to avoid memory issues (None = no limit) + + Returns: + New LeRobotDataset with images encoded as videos + """ + # Check that it's an image dataset + if len(dataset.meta.video_keys) > 0: + raise ValueError( + f"This operation is for image datasets only. 
Video dataset provided: {dataset.repo_id}" + ) + + # Get all image keys + hf_dataset = dataset.hf_dataset.with_format(None) + img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] + + if len(img_keys) == 0: + raise ValueError(f"No image keys found in dataset {dataset.repo_id}") + + # Determine which episodes to process + if episode_indices is None: + episode_indices = list(range(dataset.meta.total_episodes)) + + if repo_id is None: + repo_id = f"{dataset.repo_id}_video" + + logging.info( + f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}" + ) + logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}") + + # Create new features dict, converting image features to video features + new_features = {} + for key, value in dataset.meta.features.items(): + if key not in img_keys: + new_features[key] = value + else: + # Convert image key to video format + new_features[key] = value.copy() + new_features[key]["dtype"] = "video" # Change dtype from "image" to "video" + # Video info will be updated after episodes are encoded + + # Create new metadata for video dataset + new_meta = LeRobotDatasetMetadata.create( + repo_id=repo_id, + fps=dataset.meta.fps, + features=new_features, + robot_type=dataset.meta.robot_type, + root=output_dir, + use_videos=True, + chunks_size=dataset.meta.chunks_size, + data_files_size_in_mb=dataset.meta.data_files_size_in_mb, + video_files_size_in_mb=dataset.meta.video_files_size_in_mb, + ) + + # Create temporary directory for image extraction + temp_dir = output_dir / "temp_images" + temp_dir.mkdir(parents=True, exist_ok=True) + + # Process all episodes and batch encode videos + # Use dictionary for O(1) episode metadata lookups instead of O(n) linear search + all_episode_metadata = {} + fps = int(dataset.fps) + + try: + # Build episode metadata entries first + logging.info("Building episode metadata...") + cumulative_frame_idx = 0 + for ep_idx in 
episode_indices: + src_episode = dataset.meta.episodes[ep_idx] + ep_length = src_episode["length"] + ep_meta = { + "episode_index": ep_idx, + "length": ep_length, + "dataset_from_index": cumulative_frame_idx, + "dataset_to_index": cumulative_frame_idx + ep_length, + } + if "data/chunk_index" in src_episode: + ep_meta["data/chunk_index"] = src_episode["data/chunk_index"] + ep_meta["data/file_index"] = src_episode["data/file_index"] + all_episode_metadata[ep_idx] = ep_meta + cumulative_frame_idx += ep_length + + # Process each camera and batch encode multiple episodes together + video_file_size_limit = new_meta.video_files_size_in_mb + + # Pre-compute episode lengths for batching + episode_lengths = {ep_idx: dataset.meta.episodes["length"][ep_idx] for ep_idx in episode_indices} + + for img_key in tqdm(img_keys, desc="Processing cameras"): + # Estimate size per frame by encoding a small calibration sample + # This provides accurate compression ratio for the specific codec parameters + size_per_frame_mb = _estimate_frame_size_via_calibration( + dataset=dataset, + img_key=img_key, + episode_indices=episode_indices, + temp_dir=temp_dir, + fps=fps, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, + fast_decode=fast_decode, + ) + + logging.info(f"Processing camera: {img_key}") + chunk_idx, file_idx = 0, 0 + cumulative_timestamp = 0.0 + + # Process episodes in batches to stay under size limit + for batch_episodes in _iter_episode_batches( + episode_indices=episode_indices, + episode_lengths=episode_lengths, + size_per_frame_mb=size_per_frame_mb, + video_file_size_limit=video_file_size_limit, + max_episodes=max_episodes_per_batch, + max_frames=max_frames_per_batch, + ): + total_frames_in_batch = sum(episode_lengths[idx] for idx in batch_episodes) + logging.info( + f" Encoding batch of {len(batch_episodes)} episodes " + f"({batch_episodes[0]}-{batch_episodes[-1]}) = {total_frames_in_batch} frames" + ) + + # Save images for all episodes in this batch + imgs_dir = temp_dir 
/ f"batch_{chunk_idx}_{file_idx}" / img_key + episode_durations = _save_batch_episodes_images( + dataset=dataset, + imgs_dir=imgs_dir, + img_key=img_key, + episode_indices=batch_episodes, + num_workers=num_workers, + ) + + # Encode all batched episodes into single video + video_path = new_meta.root / new_meta.video_path.format( + video_key=img_key, chunk_index=chunk_idx, file_index=file_idx + ) + video_path.parent.mkdir(parents=True, exist_ok=True) + + encode_video_frames( + imgs_dir=imgs_dir, + video_path=video_path, + fps=fps, + vcodec=vcodec, + pix_fmt=pix_fmt, + g=g, + crf=crf, + fast_decode=fast_decode, + overwrite=True, + ) + + # Clean up temporary images + shutil.rmtree(imgs_dir) + + # Update metadata for each episode in the batch + for ep_idx, duration in zip(batch_episodes, episode_durations, strict=True): + from_timestamp = cumulative_timestamp + to_timestamp = cumulative_timestamp + duration + cumulative_timestamp = to_timestamp + + # Find episode metadata entry and add video metadata (O(1) dictionary lookup) + ep_meta = all_episode_metadata[ep_idx] + ep_meta[f"videos/{img_key}/chunk_index"] = chunk_idx + ep_meta[f"videos/{img_key}/file_index"] = file_idx + ep_meta[f"videos/{img_key}/from_timestamp"] = from_timestamp + ep_meta[f"videos/{img_key}/to_timestamp"] = to_timestamp + + # Move to next video file for next batch + chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, new_meta.chunks_size) + cumulative_timestamp = 0.0 + + # Copy and transform data files (removing image columns) + _copy_data_without_images(dataset, new_meta, episode_indices, img_keys) + + # Save episode metadata + episodes_df = pd.DataFrame(list(all_episode_metadata.values())) + episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet" + episodes_path.parent.mkdir(parents=True, exist_ok=True) + episodes_df.to_parquet(episodes_path, index=False) + + # Update metadata info + new_meta.info["total_episodes"] = len(episode_indices) + 
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata.values()) + new_meta.info["total_tasks"] = dataset.meta.total_tasks + new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"} + + # Update video info for all image keys (now videos) + # We need to manually set video info since update_video_info() checks video_keys first + for img_key in img_keys: + if not new_meta.features[img_key].get("info", None): + video_path = new_meta.root / new_meta.video_path.format( + video_key=img_key, chunk_index=0, file_index=0 + ) + new_meta.info["features"][img_key]["info"] = get_video_info(video_path) + + write_info(new_meta.info, new_meta.root) + + # Copy stats and tasks + if dataset.meta.stats is not None: + # Remove image stats + new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys} + write_stats(new_stats, new_meta.root) + + if dataset.meta.tasks is not None: + write_tasks(dataset.meta.tasks, new_meta.root) + + finally: + # Clean up temporary directory + if temp_dir.exists(): + shutil.rmtree(temp_dir) + + logging.info(f"Completed converting {dataset.repo_id} to video format") + logging.info(f"New dataset saved to: {output_dir}") + + # Return new dataset + return LeRobotDataset(repo_id=repo_id, root=output_dir) diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index e835b1de6..4ba6ce44f 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -66,23 +66,23 @@ Remove camera feature: --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" -Convert image dataset to video format (saves locally): +Convert image dataset to video format and save locally: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --operation.output_dir /path/to/output/pusht_video -Convert image dataset 
and save with new repo_id: +Convert image dataset to video format and save with new repo_id: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video + --operation.type convert_image_to_video -Convert and push to hub: +Convert image dataset to video format and push to hub: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ - --operation.type convert_to_video \ + --operation.type convert_image_to_video \ --push_to_hub true Using JSON config file: @@ -92,24 +92,19 @@ Using JSON config file: import logging import shutil -from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass from pathlib import Path -import pandas as pd -from tqdm import tqdm - from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( + convert_image_to_video_dataset, delete_episodes, merge_datasets, remove_feature, split_dataset, ) -from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata -from lerobot.datasets.utils import write_stats, write_tasks -from lerobot.datasets.video_utils import encode_video_frames, get_video_info -from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.utils.constants import HF_LEROBOT_HOME from lerobot.utils.utils import init_logging @@ -138,8 +133,8 @@ class RemoveFeatureConfig: @dataclass -class ConvertToVideoConfig: - type: str = "convert_to_video" +class ConvertImageToVideoConfig: + type: str = "convert_image_to_video" output_dir: str | None = None vcodec: str = "libsvtav1" pix_fmt: str = "yuv420p" @@ -148,12 +143,16 @@ class ConvertToVideoConfig: fast_decode: int = 0 episode_indices: list[int] | None = None num_workers: int = 4 + max_episodes_per_batch: int | None = None + max_frames_per_batch: int | None = None @dataclass class 
EditDatasetConfig: repo_id: str - operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig + operation: ( + DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig + ) root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -297,362 +296,7 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() -def save_episode_images_for_video( - dataset: LeRobotDataset, - imgs_dir: Path, - img_key: str, - episode_index: int, - num_workers: int = 4, -) -> None: - """Save images from a specific episode and camera to disk for video encoding. - - Args: - dataset: The LeRobot dataset to extract images from - imgs_dir: Directory to save images to - img_key: The image key (camera) to extract - episode_index: Index of the episode to save - num_workers: Number of threads for parallel image saving - """ - # Create directory - imgs_dir.mkdir(parents=True, exist_ok=True) - - # Get dataset without torch format for PIL image access - hf_dataset = dataset.hf_dataset.with_format(None) - - # Select only this camera's images - imgs_dataset = hf_dataset.select_columns(img_key) - - # Get episode start and end indices - from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] - to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] - - # Get all items for this episode - episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) - - # Define function to save a single image - def save_single_image(i_item_tuple): - i, item = i_item_tuple - img = item[img_key] - # Use frame-XXXXXX.png format to match encode_video_frames expectations - img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100) - return i - - # Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png) - items = list(enumerate(episode_dataset)) - - with ThreadPoolExecutor(max_workers=num_workers) as 
executor: - futures = [executor.submit(save_single_image, item) for item in items] - for future in as_completed(futures): - future.result() # This will raise any exceptions that occurred - - -def encode_episode_videos( - dataset: LeRobotDataset, - new_meta: LeRobotDatasetMetadata, - episode_index: int, - vcodec: str, - pix_fmt: str, - g: int, - crf: int, - fast_decode: int, - temp_dir: Path, - num_image_workers: int = 4, -) -> dict[str, dict]: - """Encode videos for a single episode and return video metadata. - - Args: - dataset: Source dataset with images - new_meta: Metadata object for the new video dataset - episode_index: Episode index to process - vcodec: Video codec - pix_fmt: Pixel format - g: Group of pictures size - crf: Constant rate factor - fast_decode: Fast decode tuning - temp_dir: Temporary directory for images - num_image_workers: Number of workers for saving images - - Returns: - Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps) - """ - hf_dataset = dataset.hf_dataset.with_format(None) - img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] - - video_metadata = {} - fps = int(dataset.fps) # Convert to int for PyAV compatibility - episode_length = dataset.meta.episodes["length"][episode_index] - episode_duration = episode_length / dataset.fps # Use original fps for duration calculation - - for img_key in img_keys: - # Save images temporarily - imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key - save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers) - - # Determine chunk and file indices - # For simplicity, we'll put each episode in its own file - chunk_idx = episode_index // new_meta.chunks_size - file_idx = episode_index % new_meta.chunks_size - - # Create video path in the new dataset structure - video_path = new_meta.root / new_meta.video_path.format( - video_key=img_key, chunk_index=chunk_idx, file_index=file_idx - ) - 
video_path.parent.mkdir(parents=True, exist_ok=True) - - # Encode video - encode_video_frames( - imgs_dir=imgs_dir, - video_path=video_path, - fps=fps, - vcodec=vcodec, - pix_fmt=pix_fmt, - g=g, - crf=crf, - fast_decode=fast_decode, - overwrite=True, - ) - - # Clean up temporary images - shutil.rmtree(imgs_dir) - - # Store video metadata - video_metadata[img_key] = { - f"videos/{img_key}/chunk_index": chunk_idx, - f"videos/{img_key}/file_index": file_idx, - f"videos/{img_key}/from_timestamp": 0.0, - f"videos/{img_key}/to_timestamp": episode_duration, - } - - return video_metadata - - -def convert_dataset_to_videos( - dataset: LeRobotDataset, - output_dir: Path, - repo_id: str | None = None, - vcodec: str = "libsvtav1", - pix_fmt: str = "yuv420p", - g: int = 2, - crf: int = 30, - fast_decode: int = 0, - episode_indices: list[int] | None = None, - num_workers: int = 4, -) -> LeRobotDataset: - """Convert image-based dataset to video-based dataset. - - Creates a new LeRobotDataset with videos instead of images, following the proper - LeRobot dataset structure with videos stored in chunked MP4 files. - - Args: - dataset: The source LeRobot dataset with images - output_dir: Directory to save the new video dataset - repo_id: Repository ID for the new dataset (default: original_id + "_video") - vcodec: Video codec (default: libsvtav1) - pix_fmt: Pixel format (default: yuv420p) - g: Group of pictures size (default: 2) - crf: Constant rate factor (default: 30) - fast_decode: Fast decode tuning (default: 0) - episode_indices: List of episode indices to convert (None = all episodes) - num_workers: Number of threads for parallel processing (default: 4) - - Returns: - New LeRobotDataset with videos - """ - # Check that it's an image dataset - if len(dataset.meta.video_keys) > 0: - raise ValueError( - f"This operation is for image datasets only. 
Video dataset provided: {dataset.repo_id}" - ) - - # Get all image keys - hf_dataset = dataset.hf_dataset.with_format(None) - img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] - - if len(img_keys) == 0: - raise ValueError(f"No image keys found in dataset {dataset.repo_id}") - - # Determine which episodes to process - if episode_indices is None: - episode_indices = list(range(dataset.meta.total_episodes)) - - if repo_id is None: - repo_id = f"{dataset.repo_id}_video" - - logging.info( - f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}" - ) - logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}") - - # Create new features dict, converting image features to video features - new_features = {} - for key, value in dataset.meta.features.items(): - if key not in img_keys: - new_features[key] = value - else: - # Convert image key to video format - new_features[key] = value.copy() - new_features[key]["dtype"] = "video" # Change dtype from "image" to "video" - # Video info will be updated after episodes are encoded - - # Create new metadata for video dataset - new_meta = LeRobotDatasetMetadata.create( - repo_id=repo_id, - fps=dataset.meta.fps, - features=new_features, - robot_type=dataset.meta.robot_type, - root=output_dir, - use_videos=True, - chunks_size=dataset.meta.chunks_size, - data_files_size_in_mb=dataset.meta.data_files_size_in_mb, - video_files_size_in_mb=dataset.meta.video_files_size_in_mb, - ) - - # Create temporary directory for image extraction - temp_dir = output_dir / "temp_images" - temp_dir.mkdir(parents=True, exist_ok=True) - - # Process each episode - all_episode_metadata = [] - - try: - for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"): - # Get episode metadata from source - src_episode = dataset.meta.episodes[ep_idx] - - # Encode videos for this episode - video_metadata = encode_episode_videos( - dataset=dataset, - 
new_meta=new_meta, - episode_index=ep_idx, - vcodec=vcodec, - pix_fmt=pix_fmt, - g=g, - crf=crf, - fast_decode=fast_decode, - temp_dir=temp_dir, - num_image_workers=num_workers, - ) - - # Build episode metadata - episode_meta = { - "episode_index": ep_idx, - "length": src_episode["length"], - "dataset_from_index": ep_idx * src_episode["length"], - "dataset_to_index": (ep_idx + 1) * src_episode["length"], - } - - # Add video metadata - for img_key in img_keys: - episode_meta.update(video_metadata[img_key]) - - # Add data chunk/file info (using same structure as source) - if "data/chunk_index" in src_episode: - episode_meta["data/chunk_index"] = src_episode["data/chunk_index"] - episode_meta["data/file_index"] = src_episode["data/file_index"] - - all_episode_metadata.append(episode_meta) - - # Copy and transform data files (removing image columns) - _copy_data_without_images(dataset, new_meta, episode_indices, img_keys) - - # Save episode metadata - episodes_df = pd.DataFrame(all_episode_metadata) - episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet" - episodes_path.parent.mkdir(parents=True, exist_ok=True) - episodes_df.to_parquet(episodes_path, index=False) - - # Update metadata info - new_meta.info["total_episodes"] = len(episode_indices) - new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata) - new_meta.info["total_tasks"] = dataset.meta.total_tasks - new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"} - - # Update video info for all image keys (now videos) - # We need to manually set video info since update_video_info() checks video_keys first - for img_key in img_keys: - if not new_meta.features[img_key].get("info", None): - video_path = new_meta.root / new_meta.video_path.format( - video_key=img_key, chunk_index=0, file_index=0 - ) - new_meta.info["features"][img_key]["info"] = get_video_info(video_path) - - from lerobot.datasets.utils import write_info - - write_info(new_meta.info, 
new_meta.root) - - # Copy stats and tasks - if dataset.meta.stats is not None: - # Remove image stats - new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys} - write_stats(new_stats, new_meta.root) - - if dataset.meta.tasks is not None: - write_tasks(dataset.meta.tasks, new_meta.root) - - finally: - # Clean up temporary directory - if temp_dir.exists(): - shutil.rmtree(temp_dir) - - logging.info(f"✓ Completed converting {dataset.repo_id} to video format") - logging.info(f"New dataset saved to: {output_dir}") - - # Return new dataset - return LeRobotDataset(repo_id=repo_id, root=output_dir) - - -def _copy_data_without_images( - src_dataset: LeRobotDataset, - dst_meta: LeRobotDatasetMetadata, - episode_indices: list[int], - img_keys: list[str], -) -> None: - """Copy data files without image columns. - - Args: - src_dataset: Source dataset - dst_meta: Destination metadata - episode_indices: Episodes to include - img_keys: Image keys to remove - """ - from lerobot.datasets.utils import DATA_DIR - - data_dir = src_dataset.root / DATA_DIR - parquet_files = sorted(data_dir.glob("*/*.parquet")) - - if not parquet_files: - raise ValueError(f"No parquet files found in {data_dir}") - - episode_set = set(episode_indices) - - for src_path in tqdm(parquet_files, desc="Processing data files"): - df = pd.read_parquet(src_path).reset_index(drop=True) - - # Filter to only include selected episodes - df = df[df["episode_index"].isin(episode_set)].copy() - - if len(df) == 0: - continue - - # Remove image columns - columns_to_drop = [col for col in img_keys if col in df.columns] - if columns_to_drop: - df = df.drop(columns=columns_to_drop) - - # Get chunk and file indices from path - relative_path = src_path.relative_to(src_dataset.root) - chunk_dir = relative_path.parts[1] - file_name = relative_path.parts[2] - chunk_idx = int(chunk_dir.split("-")[1]) - file_idx = int(file_name.split("-")[1].split(".")[0]) - - # Write to destination without pandas index - 
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet" - dst_path.parent.mkdir(parents=True, exist_ok=True) - df.to_parquet(dst_path, index=False) - - -def handle_convert_to_video(cfg: EditDatasetConfig) -> None: +def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: # Note: Parser may create any config type with the right fields, so we access fields directly # instead of checking isinstance() dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) @@ -664,8 +308,12 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None: if cfg.new_repo_id: # Use new_repo_id for both local storage and hub push output_repo_id = cfg.new_repo_id - output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id - logging.info(f"Saving to new dataset: {cfg.new_repo_id}") + # Place new dataset as a sibling to the original dataset + # Get the parent of the actual dataset root (not cfg.root which might be the lerobot cache dir) + # Extract just the dataset name (after last slash) for the local directory + local_dir_name = cfg.new_repo_id.split("/")[-1] + output_dir = dataset.root.parent / local_dir_name + logging.info(f"Saving to new dataset: {cfg.new_repo_id} at {output_dir}") elif output_dir_config: # Use custom output directory for local-only storage output_dir = Path(output_dir_config) @@ -675,12 +323,15 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None: else: # Auto-generate name: append "_video" to original repo_id output_repo_id = f"{cfg.repo_id}_video" - output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id + # Place new dataset as a sibling to the original dataset + # Extract just the dataset name (after last slash) for the local directory + local_dir_name = output_repo_id.split("/")[-1] + output_dir = dataset.root.parent / local_dir_name logging.info(f"Saving to auto-generated location: {output_dir}") logging.info(f"Converting dataset 
{cfg.repo_id} to video format") - new_dataset = convert_dataset_to_videos( + new_dataset = convert_image_to_video_dataset( dataset=dataset, output_dir=output_dir, repo_id=output_repo_id, @@ -691,6 +342,8 @@ def handle_convert_to_video(cfg: EditDatasetConfig) -> None: fast_decode=getattr(cfg.operation, "fast_decode", 0), episode_indices=getattr(cfg.operation, "episode_indices", None), num_workers=getattr(cfg.operation, "num_workers", 4), + max_episodes_per_batch=getattr(cfg.operation, "max_episodes_per_batch", None), + max_frames_per_batch=getattr(cfg.operation, "max_frames_per_batch", None), ) logging.info("Video dataset created successfully!") @@ -718,8 +371,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) - elif operation_type == "convert_to_video": - handle_convert_to_video(cfg) + elif operation_type == "convert_image_to_video": + handle_convert_image_to_video(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 3a4516fc8..35a369de9 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -29,7 +29,7 @@ from lerobot.datasets.dataset_tools import ( remove_feature, split_dataset, ) -from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos +from lerobot.scripts.lerobot_edit_dataset import convert_image_to_video_dataset @pytest.fixture @@ -1050,7 +1050,7 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): assert "reward" in modified_dataset.meta.features -def test_convert_dataset_to_videos(tmp_path): +def test_convert_image_to_video_dataset(tmp_path): """Test converting lerobot/pusht_image dataset to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -1071,7 +1071,7 @@ def test_convert_dataset_to_videos(tmp_path): assert "observation.image" in 
source_dataset.meta.features # Convert to video dataset (only first 2 episodes for speed) - video_dataset = convert_dataset_to_videos( + video_dataset = convert_image_to_video_dataset( dataset=source_dataset, output_dir=output_dir, repo_id="lerobot/pusht_video", @@ -1113,7 +1113,7 @@ def test_convert_dataset_to_videos(tmp_path): shutil.rmtree(output_dir) -def test_convert_dataset_to_videos_subset_episodes(tmp_path): +def test_convert_image_to_video_dataset_subset_episodes(tmp_path): """Test converting only specific episodes from lerobot/pusht_image to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset @@ -1132,7 +1132,7 @@ def test_convert_dataset_to_videos_subset_episodes(tmp_path): # Convert only episode 0 to video (subset of loaded episodes) episode_indices = [0] - video_dataset = convert_dataset_to_videos( + video_dataset = convert_image_to_video_dataset( dataset=source_dataset, output_dir=output_dir, repo_id="lerobot/pusht_video_subset", From 13bfee1aa486a6ca64eb410332a0bc47d008c451 Mon Sep 17 00:00:00 2001 From: Alexis D Date: Tue, 20 Jan 2026 02:20:30 -0800 Subject: [PATCH 054/111] Set 10 direction bit for Current Load attribute (#1014) --- src/lerobot/motors/feetech/tables.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lerobot/motors/feetech/tables.py b/src/lerobot/motors/feetech/tables.py index 91e844a72..56500e527 100644 --- a/src/lerobot/motors/feetech/tables.py +++ b/src/lerobot/motors/feetech/tables.py @@ -205,6 +205,7 @@ MODEL_BAUDRATE_TABLE = { # Sign-Magnitude encoding bits STS_SMS_SERIES_ENCODINGS_TABLE = { + "Present_Load": 10, "Homing_Offset": 11, "Goal_Position": 15, "Goal_Velocity": 15, From d36dfcdf7155debdaa17cc8fbae35eb150dc965f Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Tue, 20 Jan 2026 15:00:45 +0100 Subject: [PATCH 055/111] fix(discord link): fixing discord link in CONTRIBUTING.md (#2826) Signed-off-by: Caroline Pascal --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index abca0d821..c51a48831 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -14,7 +14,7 @@ You can contribute in many ways: - **Documentation:** Improve examples, guides, and docstrings. - **Feedback:** Submit tickets related to bugs or desired new features. -If you are unsure where to start, join our [Discord Channel](https://discord.gg/JkrYNdmw). +If you are unsure where to start, join our [Discord Channel](https://discord.gg/q8Dzzpym3f). ## Development Setup From 9919b16b3693a756388dd2cccdd04396308fd0ae Mon Sep 17 00:00:00 2001 From: sato_shinji Date: Tue, 20 Jan 2026 14:17:38 +0000 Subject: [PATCH 056/111] fix: ensure action tensors are moved to client_device in async training (#2792) * feat(async_inference): server always sends CPU tensors, client handles device conversion * fix:fix the type annotation of RawObservation in src/lerobot/async_inference/helpers.py * update the import of robot_client --------- Co-authored-by: Sato shinji Co-authored-by: Steven Palma Co-authored-by: KB --- docs/source/async.mdx | 1 + examples/tutorial/async-inf/robot_client.py | 1 + src/lerobot/async_inference/configs.py | 10 ++++++++++ src/lerobot/async_inference/helpers.py | 5 +++-- src/lerobot/async_inference/policy_server.py | 2 ++ src/lerobot/async_inference/robot_client.py | 20 ++++++++++++++++++-- tests/async_inference/test_e2e.py | 10 ++++++++-- 7 files changed, 43 insertions(+), 6 deletions(-) diff --git a/docs/source/async.mdx b/docs/source/async.mdx index 1d3e0edbf..3244fc2a3 100644 --- a/docs/source/async.mdx +++ b/docs/source/async.mdx @@ -195,6 +195,7 @@ client_cfg = RobotClientConfig( robot=robot_cfg, server_address="localhost:8080", policy_device="mps", + client_device="cpu", policy_type="smolvla", pretrained_name_or_path="/smolvla_async", chunk_size_threshold=0.5, diff --git a/examples/tutorial/async-inf/robot_client.py b/examples/tutorial/async-inf/robot_client.py index eb3751169..db6ead3fe 100644 --- 
a/examples/tutorial/async-inf/robot_client.py +++ b/examples/tutorial/async-inf/robot_client.py @@ -30,6 +30,7 @@ def main(): robot=robot_cfg, server_address=server_address, policy_device="mps", + client_device="cpu", policy_type="act", pretrained_name_or_path="/robot_learning_tutorial_act", chunk_size_threshold=0.5, # g diff --git a/src/lerobot/async_inference/configs.py b/src/lerobot/async_inference/configs.py index d1768a323..2e3fe576d 100644 --- a/src/lerobot/async_inference/configs.py +++ b/src/lerobot/async_inference/configs.py @@ -126,6 +126,12 @@ class RobotClientConfig: # Device configuration policy_device: str = field(default="cpu", metadata={"help": "Device for policy inference"}) + client_device: str = field( + default="cpu", + metadata={ + "help": "Device to move actions to after receiving from server (e.g., for downstream planners)" + }, + ) # Control behavior configuration chunk_size_threshold: float = field(default=0.5, metadata={"help": "Threshold for chunk size control"}) @@ -161,6 +167,9 @@ class RobotClientConfig: if not self.policy_device: raise ValueError("policy_device cannot be empty") + if not self.client_device: + raise ValueError("client_device cannot be empty") + if self.chunk_size_threshold < 0 or self.chunk_size_threshold > 1: raise ValueError(f"chunk_size_threshold must be between 0 and 1, got {self.chunk_size_threshold}") @@ -184,6 +193,7 @@ class RobotClientConfig: "policy_type": self.policy_type, "pretrained_name_or_path": self.pretrained_name_or_path, "policy_device": self.policy_device, + "client_device": self.client_device, "chunk_size_threshold": self.chunk_size_threshold, "fps": self.fps, "actions_per_chunk": self.actions_per_chunk, diff --git a/src/lerobot/async_inference/helpers.py b/src/lerobot/async_inference/helpers.py index 2158f51ac..8b12920d9 100644 --- a/src/lerobot/async_inference/helpers.py +++ b/src/lerobot/async_inference/helpers.py @@ -18,6 +18,7 @@ import os import time from dataclasses import dataclass, field 
from pathlib import Path +from typing import Any import torch @@ -39,8 +40,8 @@ from lerobot.utils.utils import init_logging Action = torch.Tensor -# observation as received from the robot -RawObservation = dict[str, torch.Tensor] +# observation as received from the robot (can be numpy arrays, floats, etc.) +RawObservation = dict[str, Any] # observation as those recorded in LeRobot dataset (keys are different) LeRobotObservation = dict[str, torch.Tensor] diff --git a/src/lerobot/async_inference/policy_server.py b/src/lerobot/async_inference/policy_server.py index ab2e6bcd8..aedce2a74 100644 --- a/src/lerobot/async_inference/policy_server.py +++ b/src/lerobot/async_inference/policy_server.py @@ -381,6 +381,8 @@ class PolicyServer(services_pb2_grpc.AsyncInferenceServicer): action_tensor = torch.stack(processed_actions, dim=1).squeeze(0) self.logger.debug(f"Postprocessed action shape: {action_tensor.shape}") + action_tensor = action_tensor.detach().cpu() + """5. Convert to TimedAction list""" action_chunk = self._time_action_chunk( observation_t.get_timestamp(), list(action_tensor), observation_t.get_timestep() diff --git a/src/lerobot/async_inference/robot_client.py b/src/lerobot/async_inference/robot_client.py index f26639dc1..e4d21652a 100644 --- a/src/lerobot/async_inference/robot_client.py +++ b/src/lerobot/async_inference/robot_client.py @@ -25,6 +25,7 @@ python src/lerobot/async_inference/robot_client.py \ --policy_type=act \ --pretrained_name_or_path=user/model \ --policy_device=mps \ + --client_device=cpu \ --actions_per_chunk=50 \ --chunk_size_threshold=0.5 \ --aggregate_fn_name=weighted_average \ @@ -40,6 +41,7 @@ from collections.abc import Callable from dataclasses import asdict from pprint import pformat from queue import Queue +from typing import Any import draccus import grpc @@ -47,7 +49,6 @@ import torch from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense 
import RealSenseCameraConfig # noqa: F401 -from lerobot.processor import RobotAction from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -285,6 +286,21 @@ class RobotClient: timed_actions = pickle.loads(actions_chunk.data) # nosec deserialize_time = time.perf_counter() - deserialize_start + # Log device type of received actions + if len(timed_actions) > 0: + received_device = timed_actions[0].get_action().device.type + self.logger.debug(f"Received actions on device: {received_device}") + + # Move actions to client_device (e.g., for downstream planners that need GPU) + client_device = self.config.client_device + if client_device != "cpu": + for timed_action in timed_actions: + if timed_action.get_action().device.type != client_device: + timed_action.action = timed_action.get_action().to(client_device) + self.logger.debug(f"Converted actions to device: {client_device}") + else: + self.logger.debug(f"Actions kept on device: {client_device}") + self.action_chunk_size = max(self.action_chunk_size, len(timed_actions)) # Calculate network latency if we have matching observations @@ -351,7 +367,7 @@ class RobotClient: action = {key: action_tensor[i].item() for i, key in enumerate(self.robot.action_features)} return action - def control_loop_action(self, verbose: bool = False) -> RobotAction: + def control_loop_action(self, verbose: bool = False) -> dict[str, Any]: """Reading and performing actions in local queue""" # Lock only for queue operations diff --git a/tests/async_inference/test_e2e.py b/tests/async_inference/test_e2e.py index 11941ce32..54ca29b48 100644 --- a/tests/async_inference/test_e2e.py +++ b/tests/async_inference/test_e2e.py @@ -144,12 +144,18 @@ def test_async_inference_e2e(monkeypatch): client = RobotClient(client_config) assert client.start(), "Client failed initial handshake with the server" - # Track action chunks received without modifying RobotClient - action_chunks_received = {"count": 0} + # Track action chunks received and verify 
device type + action_chunks_received = {"count": 0, "actions_on_cpu": True} original_aggregate = client._aggregate_action_queues def counting_aggregate(*args, **kwargs): action_chunks_received["count"] += 1 + # Check that all received actions are on CPU + if args: + for timed_action in args[0]: # args[0] is the list of TimedAction + action_tensor = timed_action.get_action() + if action_tensor.device.type != "cpu": + action_chunks_received["actions_on_cpu"] = False return original_aggregate(*args, **kwargs) monkeypatch.setattr(client, "_aggregate_action_queues", counting_aggregate) From 9ca680dce28f2a58c6af12a62980b1ae86b1659f Mon Sep 17 00:00:00 2001 From: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com> Date: Wed, 21 Jan 2026 00:54:24 +0800 Subject: [PATCH 057/111] Update README.md (#2827) Add Chinese doc link. Signed-off-by: Tommy in Tongji <36354458+TommyZihao@users.noreply.github.com> --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 57fec2e5f..d60cd35a9 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,7 @@ Learn how to implement your own simulation environment or benchmark and distribu ## Resources - **[Documentation](https://huggingface.co/docs/lerobot/index):** The complete guide to tutorials & API. +- **[Chinese Tutorials: LeRobot+SO-ARM101中文教程-同济子豪兄](https://zihao-ai.feishu.cn/wiki/space/7589642043471924447)** Detailed doc for assembling, teleoperate, dataset, train, deploy. Verified by Seed Studio and 5 global hackathon players. - **[Discord](https://discord.gg/q8Dzzpym3f):** Join the `LeRobot` server to discuss with the community. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. 
From 0b067df57d21d3a02d6c511f1609172fa39ac29b Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 20 Jan 2026 18:02:38 +0100 Subject: [PATCH 058/111] feat(robots): add context managers (#2828) --- src/lerobot/robots/robot.py | 26 +++++++++++++++++++++++ src/lerobot/teleoperators/teleoperator.py | 26 +++++++++++++++++++++++ 2 files changed, 52 insertions(+) diff --git a/src/lerobot/robots/robot.py b/src/lerobot/robots/robot.py index d1021daf4..d165886b9 100644 --- a/src/lerobot/robots/robot.py +++ b/src/lerobot/robots/robot.py @@ -58,6 +58,32 @@ class Robot(abc.ABC): def __str__(self) -> str: return f"{self.id} {self.__class__.__name__}" + def __enter__(self): + """ + Context manager entry. + Automatically connects to the robot. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit. + Automatically disconnects, ensuring resources are released even on error. + """ + self.disconnect() + + def __del__(self) -> None: + """ + Destructor safety net. + Attempts to disconnect if the object is garbage collected without cleanup. + """ + try: + if self.is_connected: + self.disconnect() + except Exception: # nosec B110 + pass + # TODO(aliberts): create a proper Feature class for this that links with datasets @property @abc.abstractmethod diff --git a/src/lerobot/teleoperators/teleoperator.py b/src/lerobot/teleoperators/teleoperator.py index cd9e3a53d..847b88b7f 100644 --- a/src/lerobot/teleoperators/teleoperator.py +++ b/src/lerobot/teleoperators/teleoperator.py @@ -58,6 +58,32 @@ class Teleoperator(abc.ABC): def __str__(self) -> str: return f"{self.id} {self.__class__.__name__}" + def __enter__(self): + """ + Context manager entry. + Automatically connects to the teleoperator. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit. + Automatically disconnects, ensuring resources are released even on error. 
+ """ + self.disconnect() + + def __del__(self) -> None: + """ + Destructor safety net. + Attempts to disconnect if the object is garbage collected without cleanup. + """ + try: + if self.is_connected: + self.disconnect() + except Exception: # nosec B110 + pass + @property @abc.abstractmethod def action_features(self) -> dict: From 961277d86e2467257c3916257baa4817632f2ed5 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 22 Jan 2026 12:24:12 +0100 Subject: [PATCH 059/111] chore(dependencies): Bump lerobot to 0.4.4 (#2840) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fa4b22bdf..75f617e75 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.4.3" +version = "0.4.4" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" dynamic = ["readme"] license = { text = "Apache-2.0" } From 6d34a986de44c5f22a9a99ed514f1b16832c3f32 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 22 Jan 2026 12:26:17 +0100 Subject: [PATCH 060/111] feat(ci): trigger manually documentation release version (#2841) --- .github/workflows/documentation.yml | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index 48a10e4bc..c7926c542 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -18,6 +18,11 @@ name: Documentation on: # Allows running this workflow manually from the Actions tab workflow_dispatch: + inputs: + version: + description: 'Version tag (e.g. 
v0.1.2) - Leave empty for standard main build' + required: false + type: string # Triggers the workflow on push events to main for the docs folder push: @@ -54,7 +59,13 @@ jobs: with: commit_sha: ${{ github.sha }} package: lerobot - additional_args: --not_python_module ${{ github.event_name == 'release' && format('--version {0}', github.event.release.tag_name) || '' }} + additional_args: >- + --not_python_module + ${{ + (github.event_name == 'release' && format('--version {0}', github.event.release.tag_name)) || + (inputs.version != '' && format('--version {0}', inputs.version)) || + '' + }} secrets: token: ${{ secrets.HUGGINGFACE_PUSH }} hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} From 9e10eb4a77856fb2dfb6784c4afc3984748387f7 Mon Sep 17 00:00:00 2001 From: Woojin Wie Date: Mon, 26 Jan 2026 06:29:37 +0900 Subject: [PATCH 061/111] fix(robots): update gripper configuration and calibration settings for OMX (#2815) --- docs/source/_toctree.yml | 2 + docs/source/omx.mdx | 197 ++++++++++++++++++ .../omx_leader/config_omx_leader.py | 2 +- .../teleoperators/omx_leader/omx_leader.py | 10 +- 4 files changed, 209 insertions(+), 2 deletions(-) create mode 100644 docs/source/omx.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 2b8086cd7..4298758b5 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -99,6 +99,8 @@ title: Unitree G1 - local: earthrover_mini_plus title: Earth Rover Mini + - local: omx + title: OMX title: "Robots" - sections: - local: phone_teleop diff --git a/docs/source/omx.mdx b/docs/source/omx.mdx new file mode 100644 index 000000000..4617ac7bd --- /dev/null +++ b/docs/source/omx.mdx @@ -0,0 +1,197 @@ +## Order and Assemble the parts + +First, assemble the OMX hardware following the official assembly guide. + +OMX Assembly Guide: https://ai.robotis.com/omx/assembly_guide_omx.html + +OMX robots are shipped preconfigured from the factory. 
Motor IDs, communication parameters, and joint offsets are already set, so no additional motor setup or calibration is required before using LeRobot. + +## Install LeRobot 🤗 + +To install LeRobot, follow our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Dynamixel SDK: + +```bash +pip install -e ".[dynamixel]" +``` + +## Connect the robot + +To find the port for each bus servo adapter, run this script: + +```bash +lerobot-find-port +``` + +This command runs and when prompted, disconnect the USB cable from either the leader or follower arm and press Enter. The output will show 'The port of this MotorsBus is [port]'. This identifies the port for the disconnected arm. Repeat for the other arm to identify both ports. + + + + +Example output on macOS: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, we strongly recommend using udev rules to assign persistent and human-readable device names to the OMX leader and follower arms. This avoids issues where device names such as ttyACM0 and ttyACM1 change when the robot is unplugged, replugged, or when the system is rebooted. + +#### 1. Find your device serial numbers + +You should have obtained the port numbers like ../../ttyACM? for the leader and follower using `lerobot-find-port`. You can match those results with the serial numbers using the `ls -l /dev/serial/by-id/` command. +To create udev rules, you need the unique serial number for each OMX device. 
The easiest way is to list devices under: + +```bash +ls -l /dev/serial/by-id/ +``` + +You will see output similar to: + +```bash +usb-ROBOTIS_OpenRB-150_228BDD7B503059384C2E3120FF0A2B19-if00 -> ../../ttyACM0 +usb-ROBOTIS_OpenRB-150_67E1ED68503059384C2E3120FF092234-if00 -> ../../ttyACM1 +``` + +In each line, the serial number is the long string after `usb-ROBOTIS_OpenRB-150_` and before `-if00`. + +Follower serial: `228BDD7B503059384C2E3120FF0A2B19` + +Leader serial: `67E1ED68503059384C2E3120FF092234` + +#### 2. Create the udev rule + +Create a new udev rule file: + +```bash +sudo nano /etc/udev/rules.d/99-omx.rules +``` + +Paste the following lines, replacing the serial numbers with the values you found above: + +```bash +SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="228BDD7B503059384C2E3120FF0A2B19", SYMLINK+="omx_follower" +SUBSYSTEM=="tty", ATTRS{idVendor}=="0403", ATTRS{serial}=="67E1ED68503059384C2E3120FF092234", SYMLINK+="omx_leader" +``` + +Save the file and reload udev rules: + +```bash +sudo udevadm control --reload-rules +sudo udevadm trigger +``` + +Now unplug and replug both devices once. + +#### 3. Verify the symlinks + +Check that the persistent device names exist: + +```bash +ls -l /dev/omx_follower /dev/omx_leader +``` + +You should see them pointing to ttyACM\* devices: + +```bash +/dev/omx_follower -> ttyACM* +/dev/omx_leader -> ttyACM* +``` + +These names remain stable across reboots and reconnections. + + + + +## Teleoperate + +After identifying the correct ports, you can directly teleoperate the follower arm using the leader arm. + + + + +### Teleoperate without camera + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port= \ + --robot.id=omx_follower_arm \ + --teleop.type=omx_leader \ + --teleop.port= \ + --teleop.id=omx_leader_arm +``` + +During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. 
Because OMX is already preconfigured, teleoperation can begin immediately without any calibration steps. + +### Teleoperate with camera + +You can also enable camera input during teleoperation by providing a camera configuration for the follower arm. + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port= \ + --robot.id=omx_follower_arm \ + --robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \ + --teleop.type=omx_leader \ + --teleop.port= \ + --teleop.id=omx_leader_arm \ + --display_data=true +``` + +When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning. + + + + +### Teleoperate without camera + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port=/dev/omx_follower \ + --robot.id=omx_follower_arm \ + --teleop.type=omx_leader \ + --teleop.port=/dev/omx_leader \ + --teleop.id=omx_leader_arm +``` + +During teleoperation, motions of the leader arm are mirrored in real time by the follower arm. Because OMX is already preconfigured, teleoperation can begin immediately without any calibration steps. + +### Teleoperate with camera + +You can also enable camera input during teleoperation by providing a camera configuration for the follower arm. + +```bash +lerobot-teleoperate \ + --robot.type=omx_follower \ + --robot.port=/dev/omx_follower \ + --robot.id=omx_follower_arm \ + --robot.cameras="{front: {type: opencv, index_or_path: '/dev/video0', width: 640, height: 480, fps: 30}}" \ + --teleop.type=omx_leader \ + --teleop.port=/dev/omx_leader \ + --teleop.id=omx_leader_arm \ + --display_data=true +``` + +When the camera is enabled, the camera stream is displayed in real time and synchronized with the robot state. 
This setup is useful for visual monitoring and can be reused later for demonstration recording and imitation learning. + + + + +Congrats 🎉, your robot is all set to learn a task on its own. + +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/robotis). diff --git a/src/lerobot/teleoperators/omx_leader/config_omx_leader.py b/src/lerobot/teleoperators/omx_leader/config_omx_leader.py index 3c0420ab2..a0eca38f7 100644 --- a/src/lerobot/teleoperators/omx_leader/config_omx_leader.py +++ b/src/lerobot/teleoperators/omx_leader/config_omx_leader.py @@ -27,4 +27,4 @@ class OmxLeaderConfig(TeleoperatorConfig): # Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze # the gripper and have it spring back to an open position on its own. - gripper_open_pos: float = 37.0 + gripper_open_pos: float = 60.0 diff --git a/src/lerobot/teleoperators/omx_leader/omx_leader.py b/src/lerobot/teleoperators/omx_leader/omx_leader.py index 4423be714..4264b0485 100644 --- a/src/lerobot/teleoperators/omx_leader/omx_leader.py +++ b/src/lerobot/teleoperators/omx_leader/omx_leader.py @@ -103,7 +103,7 @@ class OmxLeader(Teleoperator): self.calibration[motor] = MotorCalibration( id=m.id, drive_mode=drive_modes[motor], - homing_offset=0, + homing_offset=0 if motor != "gripper" else 100, range_min=0, range_max=4095, ) @@ -123,12 +123,20 @@ class OmxLeader(Teleoperator): # point self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + if motor == "gripper": + self.bus.write("Drive_Mode", motor, DriveMode.INVERTED.value) + else: + self.bus.write("Drive_Mode", motor, DriveMode.NON_INVERTED.value) + # Use 'position control current based' for gripper to be limited by the limit of the current. # For the follower gripper, it means it can grasp an object without forcing too much even tho, # its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). 
# For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger # to make it move, and it will move back to its original target position when we release the force. self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + self.bus.write("Current_Limit", "gripper", 100) + self.bus.write("Goal_Current", "gripper", 100) + self.bus.write("Homing_Offset", "gripper", 100) # Set gripper's goal pos in current position mode so that we can use it as a trigger. self.bus.enable_torque("gripper") if self.is_calibrated: From 366bef915cdff80d0a1bcf4834130eb63a479cb5 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Mon, 26 Jan 2026 16:26:49 +0000 Subject: [PATCH 062/111] add task ids to libero env cfg (#2842) --- src/lerobot/envs/configs.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index 112d3a73f..cd88b37bc 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -260,6 +260,7 @@ class HILSerlRobotEnvConfig(EnvConfig): @dataclass class LiberoEnv(EnvConfig): task: str = "libero_10" # can also choose libero_spatial, libero_object, etc. 
+ task_ids: list[int] | None = None fps: int = 30 episode_length: int | None = None obs_type: str = "pixels_agent_pos" @@ -338,10 +339,10 @@ class LiberoEnv(EnvConfig): @property def gym_kwargs(self) -> dict: - return { - "obs_type": self.obs_type, - "render_mode": self.render_mode, - } + kwargs: dict[str, Any] = {"obs_type": self.obs_type, "render_mode": self.render_mode} + if self.task_ids is not None: + kwargs["task_ids"] = self.task_ids + return kwargs @EnvConfig.register_subclass("metaworld") From 9cfb5ce5468d0f2df568f7ba3f902f13f89802a7 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 26 Jan 2026 17:53:25 +0100 Subject: [PATCH 063/111] feat(motors): add damiao motors & can bus (#2788) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damaio * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao otor bus, also add lerobot_setup_can.py and log if there is not response from a write command * precommit format * supress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * dont import can for tests --------- Signed-off-by: Steven Palma 
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Co-authored-by: Pepijn --- docs/source/_toctree.yml | 2 + docs/source/damiao.mdx | 165 +++++ pyproject.toml | 3 + src/lerobot/motors/__init__.py | 6 +- src/lerobot/motors/calibration_gui.py | 2 +- src/lerobot/motors/damiao/__init__.py | 18 + src/lerobot/motors/damiao/damiao.py | 808 ++++++++++++++++++++++ src/lerobot/motors/damiao/tables.py | 209 ++++++ src/lerobot/motors/dynamixel/dynamixel.py | 11 +- src/lerobot/motors/feetech/feetech.py | 13 +- src/lerobot/motors/motors_bus.py | 101 ++- src/lerobot/scripts/lerobot_setup_can.py | 360 ++++++++++ src/lerobot/utils/import_utils.py | 1 + tests/motors/test_damiao.py | 66 ++ 14 files changed, 1740 insertions(+), 25 deletions(-) create mode 100644 docs/source/damiao.mdx create mode 100644 src/lerobot/motors/damiao/__init__.py create mode 100644 src/lerobot/motors/damiao/damiao.py create mode 100644 src/lerobot/motors/damiao/tables.py create mode 100644 src/lerobot/scripts/lerobot_setup_can.py create mode 100644 tests/motors/test_damiao.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 4298758b5..f86dd11c7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -115,6 +115,8 @@ title: Notebooks - local: feetech title: Updating Feetech Firmware + - local: damiao + title: Damiao Motors and CAN Bus title: "Resources" - sections: - local: contributing diff --git a/docs/source/damiao.mdx b/docs/source/damiao.mdx new file mode 100644 index 000000000..45388ab9b --- /dev/null +++ b/docs/source/damiao.mdx @@ -0,0 +1,165 @@ +# Damiao Motors and CAN Bus + +This guide covers setup and usage of Damiao motors with LeRobot via CAN bus communication. + +Currently, only Linux is supported, as the OpenArms CAN adapter only has drivers for Linux. + +## Linux CAN Setup + +Before using Damiao motors, you need to set up the CAN interface on your Linux system. 
+ +### Install CAN Utilities + +```bash +sudo apt-get install can-utils +``` + +### Configure CAN Interface (Manual) + +For standard CAN FD (recommended for OpenArms): + +```bash +sudo ip link set can0 down +sudo ip link set can0 type can bitrate 1000000 dbitrate 5000000 fd on +sudo ip link set can0 up +``` + +For standard CAN (without FD): + +```bash +sudo ip link set can0 down +sudo ip link set can0 type can bitrate 1000000 +sudo ip link set can0 up +``` + +### Configure CAN Interface (Using LeRobot) + +LeRobot provides a utility script to setup and test CAN interfaces: + +```bash +# Setup multiple interfaces (e.g., OpenArms Followers with 2 CAN buses) +lerobot-setup-can --mode=setup --interfaces=can0,can1 +``` + +## Debugging CAN Communication + +Use the built-in debug tools to test motor communication: + +```bash +# Test motors on all interfaces +lerobot-setup-can --mode=test --interfaces=can0,can1 + +# Run speed/latency test +lerobot-setup-can --mode=speed --interfaces=can0 +``` + +The test mode will scan for motors (IDs 0x01-0x08) and report which ones respond. Example output: + +``` +can0: UP (CAN FD) + Motor 0x01 (joint_1): ✓ FOUND + → Response 0x11 [FD]: 00112233... + Motor 0x02 (joint_2): ✓ FOUND + Motor 0x03 (joint_3): ✗ No response + ... 
+ Summary: 2/8 motors found +``` + +## Usage + +### Basic Setup + +```python +from lerobot.motors import Motor +from lerobot.motors.damiao import DamiaoMotorsBus + +# Define your motors with send/receive CAN IDs +motors = { + "joint_1": Motor(id=0x01, motor_type_str="dm8009", recv_id=0x11), + "joint_2": Motor(id=0x02, motor_type_str="dm4340", recv_id=0x12), + "joint_3": Motor(id=0x03, motor_type_str="dm4310", recv_id=0x13), +} + +# Create the bus +bus = DamiaoMotorsBus( + port="can0", # Linux socketcan interface + motors=motors, +) + +# Connect +bus.connect() +``` + +### Reading Motor States + +```python +# Read single motor position (degrees) +position = bus.read("Present_Position", "joint_1") + +# Read from multiple motors +positions = bus.sync_read("Present_Position") # All motors +positions = bus.sync_read("Present_Position", ["joint_1", "joint_2"]) + +# Read all states at once (position, velocity, torque) +states = bus.sync_read_all_states() +# Returns: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...} +``` + +### Writing Motor Commands + +```python +# Enable torque +bus.enable_torque() + +# Set goal position (degrees) +bus.write("Goal_Position", "joint_1", 45.0) + +# Set positions for multiple motors +bus.sync_write("Goal_Position", { + "joint_1": 45.0, + "joint_2": -30.0, + "joint_3": 90.0, +}) + +# Disable torque +bus.disable_torque() +``` + +## Configuration Options + +| Parameter | Default | Description | +| -------------- | --------- | ----------------------------------------------------------- | +| `port` | - | CAN interface (`can0`) or serial port (`/dev/cu.usbmodem*`) | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | + +## Motor Configuration + +Each motor requires: + +- `id`: CAN ID for sending commands +- `motor_type_str`: One of the supported motor types (e.g., `"dm8009"`, `"dm4340"`) +- `recv_id`: CAN 
ID for receiving responses + +OpenArms default IDs follow the pattern: send ID `0x0N`, receive ID `0x1N` where N is the joint number. + +## Troubleshooting + +### No Response from Motors + +1. **Check power** +2. **Verify CAN wiring**: Check CAN-H, CAN-L, and GND connections +3. **Check motor IDs**: Use Damiao Debugging Tools to verify/configure IDs +4. **Test CAN interface**: Run `candump can0` to see if messages are being received +5. **Run diagnostics**: `lerobot-setup-can --mode=test --interfaces=can0` + +### Motor Timeout Parameter + +If motors were configured with timeout=0, they won't respond to commands. Use Damiao Debugging Tools to set a non-zero timeout value. + +### Verify CAN FD Status + +```bash +ip -d link show can0 | grep fd +``` diff --git a/pyproject.toml b/pyproject.toml index 75f617e75..27126f855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] +damiao = ["python-can>=4.2.0,<5.0.0"] # Robots gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] @@ -203,6 +204,7 @@ lerobot-info="lerobot.scripts.lerobot_info:main" lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main" lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main" lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main" +lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main" # ---------------- Tool Configurations ---------------- [tool.setuptools.packages.find] @@ -278,6 +280,7 @@ default.extend-ignore-identifiers-re = [ "thw", "inpt", "ROBOTIS", + "OT_VALUE" ] # TODO: Uncomment when ready to use diff --git a/src/lerobot/motors/__init__.py b/src/lerobot/motors/__init__.py index 850ef33d7..5df80d5ba 100644 --- a/src/lerobot/motors/__init__.py +++ b/src/lerobot/motors/__init__.py @@ -14,4 +14,8 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus +from .motors_bus import ( + Motor, + MotorCalibration, + MotorNormMode, +) diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 9832a1636..02bba454f 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -18,7 +18,7 @@ from dataclasses import dataclass os.environ["PYGAME_HIDE_SUPPORT_PROMPT"] = "1" -from lerobot.motors import MotorCalibration, MotorsBus +from .motors_bus import MotorCalibration, MotorsBus BAR_LEN, BAR_THICKNESS = 450, 8 HANDLE_R = 10 diff --git a/src/lerobot/motors/damiao/__init__.py b/src/lerobot/motors/damiao/__init__.py new file mode 100644 index 000000000..8240138cf --- /dev/null +++ b/src/lerobot/motors/damiao/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .damiao import DamiaoMotorsBus +from .tables import * diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py new file mode 100644 index 000000000..dd0213fc3 --- /dev/null +++ b/src/lerobot/motors/damiao/damiao.py @@ -0,0 +1,808 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Portions of this file are derived from DM_Control_Python by cmjang. +# Licensed under the MIT License; see `LICENSE` for the full text: +# https://github.com/cmjang/DM_Control_Python + +import logging +import time +from contextlib import contextmanager +from copy import deepcopy +from functools import cached_property +from typing import TYPE_CHECKING, Any, TypedDict + +from lerobot.utils.import_utils import _can_available + +if TYPE_CHECKING or _can_available: + import can +else: + can.Message = object + can.interface = None + +import numpy as np + +from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.robot_utils import precise_sleep +from lerobot.utils.utils import enter_pressed, move_cursor_up + +from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value +from .tables import ( + AVAILABLE_BAUDRATES, + CAN_CMD_DISABLE, + CAN_CMD_ENABLE, + CAN_CMD_REFRESH, + CAN_CMD_SET_ZERO, + CAN_PARAM_ID, + DEFAULT_BAUDRATE, + DEFAULT_TIMEOUT_MS, + MIT_KD_RANGE, + MIT_KP_RANGE, + MOTOR_LIMIT_PARAMS, + MotorType, +) + +logger = logging.getLogger(__name__) + + +LONG_TIMEOUT_SEC = 0.1 +MEDIUM_TIMEOUT_SEC = 0.01 +SHORT_TIMEOUT_SEC = 0.001 +PRECISE_TIMEOUT_SEC = 0.0001 + + +class MotorState(TypedDict): + position: float + velocity: float + torque: float + temp_mos: float + temp_rotor: float + + +class DamiaoMotorsBus(MotorsBusBase): + """ + The Damiao 
implementation for a MotorsBus using CAN bus communication. + + This class uses python-can for CAN bus communication with Damiao motors. + For more info, see: + - python-can documentation: https://python-can.readthedocs.io/en/stable/ + - Seedstudio documentation: https://wiki.seeedstudio.com/damiao_series/ + - DM_Control_Python repo: https://github.com/cmjang/DM_Control_Python + """ + + # CAN-specific settings + available_baudrates = deepcopy(AVAILABLE_BAUDRATES) + default_baudrate = DEFAULT_BAUDRATE + default_timeout = DEFAULT_TIMEOUT_MS + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + can_interface: str = "auto", + use_can_fd: bool = True, + bitrate: int = 1000000, + data_bitrate: int | None = 5000000, + ): + """ + Initialize the Damiao motors bus. + + Args: + port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS) + motors: Dictionary mapping motor names to Motor objects + calibration: Optional calibration data + can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial) + use_can_fd: Whether to use CAN FD mode (default: True for OpenArms) + bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps) + data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False + """ + super().__init__(port, motors, calibration) + self.port = port + self.can_interface = can_interface + self.use_can_fd = use_can_fd + self.bitrate = bitrate + self.data_bitrate = data_bitrate + self.canbus: can.interface.Bus | None = None + self._is_connected = False + + # Map motor names to CAN IDs + self._motor_can_ids: dict[str, int] = {} + self._recv_id_to_motor: dict[int, str] = {} + self._motor_types: dict[str, MotorType] = {} + + for name, motor in self.motors.items(): + if motor.motor_type_str is None: + raise ValueError(f"Motor '{name}' is missing required 'motor_type'") + self._motor_types[name] = 
def connect(self, handshake: bool = True) -> None:
    """
    Open the CAN bus and initialize communication.

    When ``can_interface`` is "auto" it is resolved from the port name
    ("/dev/..." selects slcan, anything else socketcan) and the resolved value
    is stored back on the instance.

    Args:
        handshake: If True, ping all motors to verify they're present

    Raises:
        DeviceAlreadyConnectedError: If the bus is already connected.
        ConnectionError: If the bus cannot be opened or the handshake fails.
            In that case any bus that was opened is shut down again.
    """
    if self.is_connected:
        raise DeviceAlreadyConnectedError(
            f"{self.__class__.__name__}('{self.port}') is already connected."
        )

    try:
        # Auto-detect interface type based on port name
        if self.can_interface == "auto":
            if self.port.startswith("/dev/"):
                self.can_interface = "slcan"
                logger.info(f"Auto-detected slcan interface for port {self.port}")
            else:
                self.can_interface = "socketcan"
                logger.info(f"Auto-detected socketcan interface for port {self.port}")

        kwargs = {
            "channel": self.port,
            "bitrate": self.bitrate,
            "interface": self.can_interface,
        }

        # CAN FD is only requested on socketcan and only with a data bitrate.
        use_fd = self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None
        if use_fd:
            kwargs.update({"data_bitrate": self.data_bitrate, "fd": True})

        self.canbus = can.interface.Bus(**kwargs)
        self._is_connected = True

        # Report success only once the bus has actually been opened.
        if use_fd:
            logger.info(
                f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})"
            )
        else:
            logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})")

        if handshake:
            self._handshake()

        logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.")
    except Exception as e:
        self._is_connected = False
        # Release the interface if it was opened before the failure (e.g. a
        # failed handshake); otherwise the bus object would be leaked.
        if self.canbus is not None:
            try:
                self.canbus.shutdown()
            except Exception as shutdown_error:
                logger.warning(f"Failed to shut down CAN bus after connection error: {shutdown_error}")
            self.canbus = None
        raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None:
    """
    Enable torque on selected motors.

    Each motor receives the enable command up to ``num_retry + 1`` times. A
    failed attempt is followed by a short pause; the exception of the final
    attempt is propagated to the caller.
    """
    for name in self._get_motors_list(motors):
        for attempt in range(num_retry + 1):
            try:
                self._send_simple_command(name, CAN_CMD_ENABLE)
            except Exception as err:
                if attempt == num_retry:
                    raise err
                time.sleep(MEDIUM_TIMEOUT_SEC)
            else:
                # Command went through: move on to the next motor.
                break
def _recv_motor_response(
    self, expected_recv_id: int | None = None, timeout: float = 0.001
) -> can.Message | None:
    """
    Receive a response from a motor.

    The bus is polled in slices of PRECISE_TIMEOUT_SEC until a matching frame
    arrives or ``timeout`` has elapsed. Frames with a different arbitration ID
    are read and discarded.

    Args:
        expected_recv_id: If provided, only return messages from this CAN ID
        timeout: Timeout in seconds (default: 1ms for high-speed operation)
    Returns:
        CAN message if received, None otherwise (including when receiving raised)
    """
    # ``None`` cannot be rendered with ``:02X``; format the expectation once so
    # the debug messages below are valid for both filtered and unfiltered calls.
    expected_label = "any ID" if expected_recv_id is None else f"0x{expected_recv_id:02X}"
    other_ids: list[int] = []
    try:
        # Monotonic clock: the deadline is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            msg = self.canbus.recv(timeout=PRECISE_TIMEOUT_SEC)
            if not msg:
                continue
            if expected_recv_id is None or msg.arbitration_id == expected_recv_id:
                return msg
            other_ids.append(msg.arbitration_id)
            logger.debug(f"Ignoring message from 0x{msg.arbitration_id:02X}, expected {expected_label}")

        if logger.isEnabledFor(logging.DEBUG):
            if other_ids:
                seen = {f"0x{can_id:02X}" for can_id in other_ids}
                logger.debug(f"Received {len(other_ids)} msgs from {seen}, expected {expected_label}")
            else:
                logger.debug(f"No CAN messages received (expected {expected_label})")
    except Exception as e:
        # Best effort: a receive error is reported as "no response".
        logger.debug(f"Failed to receive CAN message: {e}")
    return None
def _encode_mit_packet(
    self,
    motor_type: MotorType,
    kp: float,
    kd: float,
    position_degrees: float,
    velocity_deg_per_sec: float,
    torque: float,
) -> list[int]:
    """
    Pack MIT-mode control parameters into the 8 data bytes of a CAN frame.

    Position and velocity are converted from degrees to radians, every value
    is quantized with ``_float_to_uint`` against the limits of ``motor_type``
    (or the MIT gain ranges), and the bit fields are laid out as:

        bytes 0-1           position (16 bits)
        byte 2 + high 3     velocity (12 bits)
        low 3 + byte 4      kp       (12 bits)
        byte 5 + high 6     kd       (12 bits)
        low 6 + byte 7      torque   (12 bits)
    """
    target_rad = np.radians(position_degrees)
    speed_rad = np.radians(velocity_deg_per_sec)

    pos_limit, vel_limit, torque_limit = MOTOR_LIMIT_PARAMS[motor_type]

    quantize = self._float_to_uint
    kp_bits = quantize(kp, *MIT_KP_RANGE, 12)
    kd_bits = quantize(kd, *MIT_KD_RANGE, 12)
    pos_bits = quantize(target_rad, -pos_limit, pos_limit, 16)
    vel_bits = quantize(speed_rad, -vel_limit, vel_limit, 12)
    torque_bits = quantize(torque, -torque_limit, torque_limit, 12)

    return [
        (pos_bits >> 8) & 0xFF,
        pos_bits & 0xFF,
        vel_bits >> 4,
        ((vel_bits & 0xF) << 4) | ((kp_bits >> 8) & 0xF),
        kp_bits & 0xFF,
        kd_bits >> 4,
        ((kd_bits & 0xF) << 4) | ((torque_bits >> 8) & 0xF),
        torque_bits & 0xFF,
    ]
motor.""" + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + + recv_id = self._get_motor_recv_id(motor) + if msg := self._recv_motor_response(expected_recv_id=recv_id): + self._process_response(motor_name, msg) + else: + logger.debug(f"No response from {motor_name} after MIT control command") + + def _mit_control_batch( + self, + commands: dict[NameOrID, tuple[float, float, float, float, float]], + ) -> None: + """ + Send MIT control commands to multiple motors in batch. + Sends all commands first, then collects responses. + + Args: + commands: Dict mapping motor name/ID to (kp, kd, position_deg, velocity_deg/s, torque) + Example: {'joint_1': (10.0, 0.5, 45.0, 0.0, 0.0), ...} + """ + if not commands: + return + + recv_id_to_motor: dict[int, str] = {} + + # Step 1: Send all MIT control commands + for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + # Step 2: Collect responses and update state cache + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=SHORT_TIMEOUT_SEC) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + + def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int: + """Convert float to unsigned integer for CAN 
transmission.""" + x = max(x_min, min(x_max, x)) # Clamp to range + span = x_max - x_min + data_norm = (x - x_min) / span + return int(data_norm * ((1 << bits) - 1)) + + def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float: + """Convert unsigned integer from CAN to float.""" + span = x_max - x_min + data_norm = float(x) / ((1 << bits) - 1) + return data_norm * span + x_min + + def _decode_motor_state( + self, data: bytearray | bytes, motor_type: MotorType + ) -> tuple[float, float, float, int, int]: + """ + Decode motor state from CAN data. + Returns: (position_deg, velocity_deg_s, torque, temp_mos, temp_rotor) + """ + if len(data) < 8: + raise ValueError("Invalid motor state data") + + # Extract encoded values + q_uint = (data[1] << 8) | data[2] + dq_uint = (data[3] << 4) | (data[4] >> 4) + tau_uint = ((data[4] & 0x0F) << 8) | data[5] + t_mos = data[6] + t_rotor = data[7] + + # Get motor limits + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + # Decode to physical values + position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16) + velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12) + torque = self._uint_to_float(tau_uint, -tmax, tmax, 12) + + return np.degrees(position_rad), np.degrees(velocity_rad_per_sec), torque, t_mos, t_rotor + + def _process_response(self, motor: str, msg: can.Message) -> None: + """Decode a message and update the motor state cache.""" + try: + motor_type = self._motor_types[motor] + pos, vel, torque, t_mos, t_rotor = self._decode_motor_state(msg.data, motor_type) + + self._last_known_states[motor] = { + "position": pos, + "velocity": vel, + "torque": torque, + "temp_mos": float(t_mos), + "temp_rotor": float(t_rotor), + } + except Exception as e: + logger.warning(f"Failed to decode response from {motor}: {e}") + + def read(self, data_name: str, motor: str) -> Value: + """Read a value from a single motor. 
Positions are always in degrees.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Refresh motor to get latest state + msg = self._refresh_motor(motor) + if msg is None: + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + raise ConnectionError( + f"No response from motor '{motor}' (send ID: 0x{motor_id:02X}, recv ID: 0x{recv_id:02X}). " + f"Check that: 1) Motor is powered (24V), 2) CAN wiring is correct, " + f"3) Motor IDs are configured correctly using Damiao Debugging Tools" + ) + + self._process_response(motor, msg) + return self._get_cached_value(motor, data_name) + + def _get_cached_value(self, motor: str, data_name: str) -> Value: + """Retrieve a specific value from the cache.""" + state = self._last_known_states[motor] + mapping: dict[str, Any] = { + "Present_Position": state["position"], + "Present_Velocity": state["velocity"], + "Present_Torque": state["torque"], + "Temperature_MOS": state["temp_mos"], + "Temperature_Rotor": state["temp_rotor"], + } + if data_name not in mapping: + raise ValueError(f"Unknown data_name: {data_name}") + return mapping[data_name] + + def write( + self, + data_name: str, + motor: str, + value: Value, + ) -> None: + """ + Write a value to a single motor. Positions are always in degrees. + Can write 'Goal_Position', 'Kp', or 'Kd'. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if data_name in ("Kp", "Kd"): + self._gains[motor][data_name.lower()] = float(value) + elif data_name == "Goal_Position": + kp = self._gains[motor]["kp"] + kd = self._gains[motor]["kd"] + self._mit_control(motor, kp, kd, float(value), 0.0, 0.0) + else: + raise ValueError(f"Writing {data_name} not supported in MIT mode") + + def sync_read( + self, + data_name: str, + motors: str | list[str] | None = None, + ) -> dict[str, Value]: + """ + Read the same value from multiple motors simultaneously. 
+ """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + + result = {} + for motor in target_motors: + result[motor] = self._get_cached_value(motor, data_name) + return result + + def sync_read_all_states( + self, + motors: str | list[str] | None = None, + *, + num_retry: int = 0, + ) -> dict[str, MotorState]: + """ + Read ALL motor states (position, velocity, torque) from multiple motors in ONE refresh cycle. + + Returns: + Dictionary mapping motor names to state dicts with keys: 'position', 'velocity', 'torque' + Example: {'joint_1': {'position': 45.2, 'velocity': 1.3, 'torque': 0.5}, ...} + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + + result = {} + for motor in target_motors: + result[motor] = self._last_known_states[motor].copy() + return result + + def _batch_refresh(self, motors: list[str]) -> None: + """Internal helper to refresh a list of motors and update cache.""" + # Send refresh commands + for motor in motors: + motor_id = self._get_motor_id(motor) + data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + self.canbus.send(msg) + # Small delay to reduce bus congestion if necessary, though removed in sync_read previously + # precise_sleep(PRECISE_SLEEP_SEC) + + # Collect responses + expected_recv_ids = [self._get_motor_recv_id(m) for m in motors] + responses = self._recv_all_responses(expected_recv_ids, timeout=MEDIUM_TIMEOUT_SEC) + + # Update cache + for motor in motors: + recv_id = self._get_motor_recv_id(motor) + msg = responses.get(recv_id) + if msg: + self._process_response(motor, msg) + else: + logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + + def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + """ + Write values to multiple motors simultaneously. Positions are always in degrees. 
+ """ + if data_name in ("Kp", "Kd"): + key = data_name.lower() + for motor, val in values.items(): + self._gains[motor][key] = float(val) + + elif data_name == "Goal_Position": + # Step 1: Send all MIT control commands + recv_id_to_motor: dict[int, str] = {} + for motor, value_degrees in values.items(): + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + + kp = self._gains[motor]["kp"] + kd = self._gains[motor]["kd"] + + data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self.canbus.send(msg) + precise_sleep(PRECISE_TIMEOUT_SEC) + + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + # Step 2: Collect responses and update state cache + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=MEDIUM_TIMEOUT_SEC) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + else: + # Fall back to individual writes + for motor, value in values.items(): + self.write(data_name, motor, value) + + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration data from motors.""" + # Damiao motors don't store calibration internally + # Return existing calibration or empty dict + return self.calibration if self.calibration else {} + + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration data to motors.""" + # Damiao motors don't store calibration internally + # Just cache it in memory + if cache: + self.calibration = calibration_dict + + def record_ranges_of_motion( + self, + motors: NameOrID | list[NameOrID] | None = None, + display_values: bool = True, + ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + """ + Interactively record the min/max values of each motor in degrees. 
+ + Move the joints by hand (with torque disabled) while the method streams live positions. + Press Enter to finish. + """ + target_motors = self._get_motors_list(motors) + + self.disable_torque(target_motors) + time.sleep(LONG_TIMEOUT_SEC) + + start_positions = self.sync_read("Present_Position", target_motors) + mins = start_positions.copy() + maxes = start_positions.copy() + + print("\nMove joints through their full range of motion. Press ENTER when done.") + user_pressed_enter = False + + while not user_pressed_enter: + positions = self.sync_read("Present_Position", target_motors) + + for motor in target_motors: + if motor in positions: + mins[motor] = min(positions[motor], mins.get(motor, positions[motor])) + maxes[motor] = max(positions[motor], maxes.get(motor, positions[motor])) + + if display_values: + print("\n" + "=" * 50) + print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}") + print("-" * 50) + for motor in target_motors: + if motor in positions: + print( + f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}" + ) + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + move_cursor_up(len(target_motors) + 4) + + time.sleep(LONG_TIMEOUT_SEC) + + self.enable_torque(target_motors) + + for motor in target_motors: + if (motor in mins) and (motor in maxes) and (int(abs(maxes[motor] - mins[motor])) < 5): + raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)") + + return mins, maxes + + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + """Convert motor specification to list of motor names.""" + if motors is None: + return list(self.motors.keys()) + elif isinstance(motors, str): + return [motors] + elif isinstance(motors, list): + return motors + else: + raise TypeError(f"Invalid motors type: {type(motors)}") + + def _get_motor_id(self, motor: NameOrID) -> int: + """Get CAN ID for a motor.""" + if 
@property
def is_calibrated(self) -> bool:
    """
    Check if motors are calibrated.

    True when ``self.calibration`` is non-empty. This is evaluated on every
    access rather than cached, because ``write_calibration`` can replace
    ``self.calibration`` after the first check and a cached result would
    then be stale.
    """
    return bool(self.calibration)
# Motor type definitions
class MotorType(IntEnum):
    """
    Damiao motor models known to this package.

    Members are the keys of ``MOTOR_LIMIT_PARAMS`` and ``MODEL_NAMES`` in this
    module. The Damiao bus resolves a configured motor type string to a member
    by name, after upper-casing it and replacing "-" with "_", so member names
    must stay in that form (e.g. "dm4310-48v" resolves to ``DM4310_48V``).
    """

    # NOTE(review): the integer values look like plain sequential identifiers;
    # confirm before relying on them as protocol codes.
    DM3507 = 0
    DM4310 = 1
    DM4310_48V = 2
    DM4340 = 3
    DM4340_48V = 4
    DM6006 = 5
    DM8006 = 6
    DM8009 = 7
    DM10010L = 8
    DM10010 = 9
    DMH3510 = 10
    DMH6215 = 11
    DMG6220 = 12
"dm8009", + MotorType.DM10010L: "dm10010l", + MotorType.DM10010: "dm10010", + MotorType.DMH3510: "dmh3510", + MotorType.DMH6215: "dmh6215", + MotorType.DMG6220: "dmg6220", +} + +# Motor resolution table (encoder counts per revolution) +MODEL_RESOLUTION = { + "dm3507": 65536, + "dm4310": 65536, + "dm4310_48v": 65536, + "dm4340": 65536, + "dm4340_48v": 65536, + "dm6006": 65536, + "dm8006": 65536, + "dm8009": 65536, + "dm10010l": 65536, + "dm10010": 65536, + "dmh3510": 65536, + "dmh6215": 65536, + "dmg6220": 65536, +} + +# CAN baudrates supported by Damiao motors +AVAILABLE_BAUDRATES = [ + 125000, # 0: 125 kbps + 200000, # 1: 200 kbps + 250000, # 2: 250 kbps + 500000, # 3: 500 kbps + 1000000, # 4: 1 mbps (default for OpenArms) + 2000000, # 5: 2 mbps + 2500000, # 6: 2.5 mbps + 3200000, # 7: 3.2 mbps + 4000000, # 8: 4 mbps + 5000000, # 9: 5 mbps +] +DEFAULT_BAUDRATE = 1000000 # 1 Mbps is standard for OpenArms + +# Default timeout in milliseconds +DEFAULT_TIMEOUT_MS = 1000 + +# OpenArms specific configurations +# Based on: https://docs.openarm.dev/software/setup/configure-test +# OpenArms has 7 DOF per arm (14 total for dual arm) +OPENARMS_ARM_MOTOR_IDS = { + "joint_1": {"send": 0x01, "recv": 0x11}, # J1 - Shoulder pan + "joint_2": {"send": 0x02, "recv": 0x12}, # J2 - Shoulder lift + "joint_3": {"send": 0x03, "recv": 0x13}, # J3 - Elbow flex + "joint_4": {"send": 0x04, "recv": 0x14}, # J4 - Wrist flex + "joint_5": {"send": 0x05, "recv": 0x15}, # J5 - Wrist roll + "joint_6": {"send": 0x06, "recv": 0x16}, # J6 - Wrist pitch + "joint_7": {"send": 0x07, "recv": 0x17}, # J7 - Wrist rotation +} + +OPENARMS_GRIPPER_MOTOR_IDS = { + "gripper": {"send": 0x08, "recv": 0x18}, # J8 - Gripper +} + +# Default motor types for OpenArms +OPENARMS_DEFAULT_MOTOR_TYPES = { + "joint_1": MotorType.DM8009, # Shoulder pan - high torque + "joint_2": MotorType.DM8009, # Shoulder lift - high torque + "joint_3": MotorType.DM4340, # Shoulder rotation + "joint_4": MotorType.DM4340, # Elbow flex + 
"joint_5": MotorType.DM4310, # Wrist roll + "joint_6": MotorType.DM4310, # Wrist pitch + "joint_7": MotorType.DM4310, # Wrist rotation + "gripper": MotorType.DM4310, # Gripper +} + +# MIT control parameter ranges +MIT_KP_RANGE = (0.0, 500.0) +MIT_KD_RANGE = (0.0, 5.0) + +# CAN frame command IDs +CAN_CMD_ENABLE = 0xFC +CAN_CMD_DISABLE = 0xFD +CAN_CMD_SET_ZERO = 0xFE +CAN_CMD_REFRESH = 0xCC +CAN_CMD_QUERY_PARAM = 0x33 +CAN_CMD_WRITE_PARAM = 0x55 +CAN_CMD_SAVE_PARAM = 0xAA + +# CAN ID for parameter operations +CAN_PARAM_ID = 0x7FF diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index 01bfcf544..c6752ee96 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -22,9 +22,8 @@ import logging from copy import deepcopy from enum import Enum -from lerobot.motors.encoding_utils import decode_twos_complement, encode_twos_complement - -from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from ..encoding_utils import decode_twos_complement, encode_twos_complement +from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from .tables import ( AVAILABLE_BAUDRATES, MODEL_BAUDRATE_TABLE, @@ -100,7 +99,7 @@ def _split_into_byte_chunks(value: int, length: int) -> list[int]: return data -class DynamixelMotorsBus(MotorsBus): +class DynamixelMotorsBus(SerialMotorsBus): """ The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with the motors. 
For more info, see the Dynamixel SDK Documentation: @@ -203,9 +202,9 @@ class DynamixelMotorsBus(MotorsBus): for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") - self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 2ea57af12..7ce3388b6 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -17,9 +17,8 @@ from copy import deepcopy from enum import Enum from pprint import pformat -from lerobot.motors.encoding_utils import decode_sign_magnitude, encode_sign_magnitude - -from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from ..encoding_utils import decode_sign_magnitude, encode_sign_magnitude +from ..motors_bus import Motor, MotorCalibration, NameOrID, SerialMotorsBus, Value, get_address from .tables import ( FIRMWARE_MAJOR_VERSION, FIRMWARE_MINOR_VERSION, @@ -96,7 +95,7 @@ def patch_setPacketTimeout(self, packet_length): # noqa: N802 self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50 -class FeetechMotorsBus(MotorsBus): +class FeetechMotorsBus(SerialMotorsBus): """ The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk. 
@@ -298,11 +297,11 @@ class FeetechMotorsBus(MotorsBus): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) - def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") - self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) addr, length = get_address(self.model_ctrl_table, model, "Lock") - self._write(addr, length, motor_id, 0, num_retry=num_retry) + self._write(addr, length, motor, 0, num_retry=num_retry) def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 91bee994a..c04f718b6 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -19,6 +19,8 @@ # TODO(aliberts): Add block noqa when feature below is available # https://github.com/astral-sh/ruff/issues/3711 +from __future__ import annotations + import abc import logging from contextlib import contextmanager @@ -41,6 +43,81 @@ Value: TypeAlias = int | float logger = logging.getLogger(__name__) +class MotorsBusBase(abc.ABC): + """ + Base class for all motor bus implementations. + + This is a minimal interface that all motor buses must implement, regardless of their + communication protocol (serial, CAN, etc.). 
+ """ + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + ): + self.port = port + self.motors = motors + self.calibration = calibration if calibration else {} + + @abc.abstractmethod + def connect(self, handshake: bool = True) -> None: + """Establish connection to the motors.""" + pass + + @abc.abstractmethod + def disconnect(self, disable_torque: bool = True) -> None: + """Disconnect from the motors.""" + pass + + @property + @abc.abstractmethod + def is_connected(self) -> bool: + """Check if connected to the motors.""" + pass + + @abc.abstractmethod + def read(self, data_name: str, motor: str) -> Value: + """Read a value from a single motor.""" + pass + + @abc.abstractmethod + def write(self, data_name: str, motor: str, value: Value) -> None: + """Write a value to a single motor.""" + pass + + @abc.abstractmethod + def sync_read(self, data_name: str, motors: str | list[str] | None = None) -> dict[str, Value]: + """Read a value from multiple motors.""" + pass + + @abc.abstractmethod + def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + """Write values to multiple motors.""" + pass + + @abc.abstractmethod + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors.""" + pass + + @abc.abstractmethod + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors.""" + pass + + @abc.abstractmethod + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration parameters from the motors.""" + pass + + @abc.abstractmethod + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration parameters to the motors.""" + pass + + def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]: ctrl_table = 
model_ctrl_table.get(model) if ctrl_table is None: @@ -97,6 +174,8 @@ class Motor: id: int model: str norm_mode: MotorNormMode + motor_type_str: str | None = None + recv_id: int | None = None class PortHandler(Protocol): @@ -203,15 +282,15 @@ class GroupSyncWrite(Protocol): def txPacket(self): ... -class MotorsBus(abc.ABC): +class SerialMotorsBus(MotorsBusBase): """ - A MotorsBus allows to efficiently read and write to the attached motors. + A SerialMotorsBus allows to efficiently read and write to motors connected via serial communication. It represents several motors daisy-chained together and connected through a serial port. - There are currently two implementations of this abstract class: + There are currently two implementations of this class: - DynamixelMotorsBus - FeetechMotorsBus - Note: This class may evolve in the future should we add support for other types of bus. + This class is specifically for serial-based motor protocols (Dynamixel, Feetech, etc.). A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). 
To find the port, you can run our utility script: @@ -260,9 +339,7 @@ class MotorsBus(abc.ABC): motors: dict[str, Motor], calibration: dict[str, MotorCalibration] | None = None, ): - self.port = port - self.motors = motors - self.calibration = calibration if calibration else {} + super().__init__(port, motors, calibration) self.port_handler: PortHandler self.packet_handler: PacketHandler @@ -532,7 +609,7 @@ class MotorsBus(abc.ABC): self.set_baudrate(self.default_baudrate) @abc.abstractmethod - def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]: + def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: pass @abc.abstractmethod @@ -545,13 +622,13 @@ class MotorsBus(abc.ABC): pass @abc.abstractmethod - def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: """Disable torque on selected motors. Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM). Args: - motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a + motors ( str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a list of names or `None` to affect every registered motor. Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. @@ -1194,3 +1271,7 @@ class MotorsBus(abc.ABC): for id_, value in ids_values.items(): data = self._serialize_data(value, length) self.sync_writer.addParam(id_, data) + + +# Backward compatibility alias +MotorsBus: TypeAlias = SerialMotorsBus diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py new file mode 100644 index 000000000..55de74724 --- /dev/null +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -0,0 +1,360 @@ +# Copyright 2025 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Setup and debug CAN interfaces for Damiao motors (e.g., OpenArms). + +Examples: + +Setup CAN interfaces with CAN FD: +```shell +lerobot-setup-can --mode=setup --interfaces=can0,can1,can2,can3 +``` + +Test motors on a single interface: +```shell +lerobot-setup-can --mode=test --interfaces=can0 +``` + +Test motors on all interfaces: +```shell +lerobot-setup-can --mode=test --interfaces=can0,can1,can2,can3 +``` + +Speed test: +```shell +lerobot-setup-can --mode=speed --interfaces=can0 +``` +""" + +import subprocess +import sys +import time +from dataclasses import dataclass, field + +import draccus + +from lerobot.utils.import_utils import is_package_available + +MOTOR_NAMES = { + 0x01: "joint_1", + 0x02: "joint_2", + 0x03: "joint_3", + 0x04: "joint_4", + 0x05: "joint_5", + 0x06: "joint_6", + 0x07: "joint_7", + 0x08: "gripper", +} + + +@dataclass +class CANSetupConfig: + mode: str = "test" + interfaces: str = "can0" # Comma-separated, e.g. 
"can0,can1,can2,can3" + bitrate: int = 1000000 + data_bitrate: int = 5000000 + use_fd: bool = True + motor_ids: list[int] = field(default_factory=lambda: list(range(0x01, 0x09))) + timeout: float = 1.0 + speed_iterations: int = 100 + + def get_interfaces(self) -> list[str]: + return [i.strip() for i in self.interfaces.split(",") if i.strip()] + + +def check_interface_status(interface: str) -> tuple[bool, str, bool]: + """Check if CAN interface is UP and configured.""" + try: + result = subprocess.run(["ip", "link", "show", interface], capture_output=True, text=True) # nosec B607 + if result.returncode != 0: + return False, "Interface not found", False + + output = result.stdout + is_up = "UP" in output + is_fd = "fd on" in output.lower() or "canfd" in output.lower() + status = "UP" if is_up else "DOWN" + if is_fd: + status += " (CAN FD)" + + return is_up, status, is_fd + except FileNotFoundError: + return False, "ip command not found", False + + +def setup_interface(interface: str, bitrate: int, data_bitrate: int, use_fd: bool) -> bool: + """Configure a CAN interface.""" + try: + subprocess.run(["sudo", "ip", "link", "set", interface, "down"], check=False, capture_output=True) # nosec B607 + + cmd = ["sudo", "ip", "link", "set", interface, "type", "can", "bitrate", str(bitrate)] + if use_fd: + cmd.extend(["dbitrate", str(data_bitrate), "fd", "on"]) + + result = subprocess.run(cmd, capture_output=True, text=True) # nosec B607 + if result.returncode != 0: + print(f" ✗ Failed to configure: {result.stderr}") + return False + + result = subprocess.run( # nosec B607 + ["sudo", "ip", "link", "set", interface, "up"], capture_output=True, text=True + ) + if result.returncode != 0: + print(f" ✗ Failed to bring up: {result.stderr}") + return False + + return True + except Exception as e: + print(f" ✗ Error: {e}") + return False + + +def test_motor(bus, motor_id: int, timeout: float, use_fd: bool): + """Test a single motor and return responses.""" + import can + + enable_msg = 
can.Message( + arbitration_id=motor_id, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC], + is_extended_id=False, + is_fd=use_fd, + ) + + try: + bus.send(enable_msg) + except Exception as e: + return None, f"Send error: {e}" + + responses = [] + start_time = time.time() + + while time.time() - start_time < timeout: + msg = bus.recv(timeout=0.1) + if msg: + responses.append((msg.arbitration_id, msg.data.hex(), getattr(msg, "is_fd", False))) + + disable_msg = can.Message( + arbitration_id=motor_id, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFD], + is_extended_id=False, + is_fd=use_fd, + ) + try: + bus.send(disable_msg) + except Exception: + print(f"Error sending message to motor 0x{motor_id:02X}") + + return responses, None + + +def test_interface(cfg: CANSetupConfig, interface: str): + """Test all motors on a CAN interface.""" + import can + + is_up, status, _ = check_interface_status(interface) + print(f"\n{interface}: {status}") + + if not is_up: + print(f" ⚠ Interface is not UP. 
Run: lerobot-setup-can --mode=setup --interfaces {interface}") + return {} + + try: + kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate} + if cfg.use_fd: + kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True}) + bus = can.interface.Bus(**kwargs) + except Exception as e: + print(f" ✗ Connection failed: {e}") + return {} + + results = {} + try: + while bus.recv(timeout=0.01): + pass + + for motor_id in cfg.motor_ids: + motor_name = MOTOR_NAMES.get(motor_id, f"motor_0x{motor_id:02X}") + responses, error = test_motor(bus, motor_id, cfg.timeout, cfg.use_fd) + + if error: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ {error}") + results[motor_id] = {"found": False, "error": error} + elif responses: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✓ FOUND") + for resp_id, data, is_fd in responses: + fd_flag = " [FD]" if is_fd else "" + print(f" → Response 0x{resp_id:02X}{fd_flag}: {data}") + results[motor_id] = {"found": True, "responses": responses} + else: + print(f" Motor 0x{motor_id:02X} ({motor_name}): ✗ No response") + results[motor_id] = {"found": False} + + time.sleep(0.05) + finally: + bus.shutdown() + + found = sum(1 for r in results.values() if r.get("found")) + print(f"\n Summary: {found}/{len(cfg.motor_ids)} motors found") + return results + + +def speed_test(cfg: CANSetupConfig, interface: str): + """Test communication speed with motors.""" + import can + + is_up, status, _ = check_interface_status(interface) + if not is_up: + print(f"{interface}: {status} - skipping") + return + + print(f"\n{interface}: Running speed test ({cfg.speed_iterations} iterations)...") + + try: + kwargs = {"channel": interface, "interface": "socketcan", "bitrate": cfg.bitrate} + if cfg.use_fd: + kwargs.update({"data_bitrate": cfg.data_bitrate, "fd": True}) + bus = can.interface.Bus(**kwargs) + except Exception as e: + print(f" ✗ Connection failed: {e}") + return + + responding_motor = None + for motor_id in cfg.motor_ids: + responses, 
_ = test_motor(bus, motor_id, 0.5, cfg.use_fd) + if responses: + responding_motor = motor_id + break + + if not responding_motor: + print(" ✗ No responding motors found") + bus.shutdown() + return + + print(f" Testing with motor 0x{responding_motor:02X}...") + latencies = [] + + for _ in range(cfg.speed_iterations): + start = time.perf_counter() + msg = can.Message( + arbitration_id=responding_motor, + data=[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFC], + is_extended_id=False, + is_fd=cfg.use_fd, + ) + bus.send(msg) + resp = bus.recv(timeout=0.1) + if resp: + latencies.append((time.perf_counter() - start) * 1000) + + bus.shutdown() + + if latencies: + avg_latency = sum(latencies) / len(latencies) + hz = 1000.0 / avg_latency if avg_latency > 0 else 0 + print(f" ✓ Success rate: {len(latencies)}/{cfg.speed_iterations}") + print(f" ✓ Avg latency: {avg_latency:.2f} ms") + print(f" ✓ Max frequency: {hz:.1f} Hz") + else: + print(" ✗ No successful responses") + + +def run_setup(cfg: CANSetupConfig): + """Setup CAN interfaces.""" + print("=" * 50) + print("CAN Interface Setup") + print("=" * 50) + print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}") + print(f"Bitrate: {cfg.bitrate / 1_000_000:.1f} Mbps") + if cfg.use_fd: + print(f"Data bitrate: {cfg.data_bitrate / 1_000_000:.1f} Mbps") + print() + + interfaces = cfg.get_interfaces() + for interface in interfaces: + print(f"Configuring {interface}...") + if setup_interface(interface, cfg.bitrate, cfg.data_bitrate, cfg.use_fd): + is_up, status, _ = check_interface_status(interface) + print(f" ✓ {interface}: {status}") + else: + print(f" ✗ {interface}: Failed") + + print("\nSetup complete!") + print("\nNext: Test motors with:") + print(f" lerobot-setup-can --mode=test --interfaces {','.join(interfaces)}") + + +def run_test(cfg: CANSetupConfig): + """Test motors on CAN interfaces.""" + print("=" * 50) + print("CAN Motor Test") + print("=" * 50) + print(f"Testing motors 
0x{min(cfg.motor_ids):02X}-0x{max(cfg.motor_ids):02X}") + print(f"Mode: {'CAN FD' if cfg.use_fd else 'CAN 2.0'}") + print() + + interfaces = cfg.get_interfaces() + all_results = {} + for interface in interfaces: + all_results[interface] = test_interface(cfg, interface) + + total_found = sum(sum(1 for r in res.values() if r.get("found")) for res in all_results.values()) + + print("\n" + "=" * 50) + print("Summary") + print("=" * 50) + print(f"Total motors found: {total_found}") + + if total_found == 0: + print("\n⚠ No motors found! Check:") + print(" 1. Motors are powered (24V)") + print(" 2. CAN wiring (CANH, CANL, GND)") + print(" 3. Motor timeout parameter > 0 (use Damiao tools)") + print(" 4. 120Ω termination at both cable ends") + print(f" 5. Interface configured: lerobot-setup-can --mode=setup --interfaces {interfaces[0]}") + + +def run_speed(cfg: CANSetupConfig): + """Run speed tests on CAN interfaces.""" + print("=" * 50) + print("CAN Speed Test") + print("=" * 50) + + for interface in cfg.get_interfaces(): + speed_test(cfg, interface) + + +@draccus.wrap() +def setup_can(cfg: CANSetupConfig): + if not is_package_available("can"): + print("Error: python-can not installed. 
Install with: pip install python-can") + sys.exit(1) + + if cfg.mode == "setup": + run_setup(cfg) + elif cfg.mode == "test": + run_test(cfg) + elif cfg.mode == "speed": + run_speed(cfg) + else: + print(f"Unknown mode: {cfg.mode}") + print("Available modes: setup, test, speed") + sys.exit(1) + + +def main(): + setup_can() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index a499b96c7..c33a73589 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -73,6 +73,7 @@ _transformers_available = is_package_available("transformers") _peft_available = is_package_available("peft") _scipy_available = is_package_available("scipy") _reachy2_sdk_available = is_package_available("reachy2_sdk") +_can_available = is_package_available("python-can", "can") def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/tests/motors/test_damiao.py b/tests/motors/test_damiao.py new file mode 100644 index 000000000..7ce1af34f --- /dev/null +++ b/tests/motors/test_damiao.py @@ -0,0 +1,66 @@ +"""Minimal test script for Damiao motor with ID 3.""" + +import pytest + +from lerobot.utils.import_utils import _can_available + +if not _can_available: + pytest.skip("python-can not available", allow_module_level=True) + +from lerobot.motors import Motor +from lerobot.motors.damiao import DamiaoMotorsBus + + +@pytest.mark.skip(reason="Requires physical Damiao motor and CAN interface") +def test_damiao_motor(): + motors = { + "joint_3": Motor( + id=0x03, + model="damiao", + norm_mode="degrees", + motor_type_str="dm4310", + recv_id=0x13, + ), + } + + bus = DamiaoMotorsBus(port="can0", motors=motors) + + try: + print("Connecting...") + bus.connect() + print("✓ Connected") + + print("Enabling torque...") + bus.enable_torque() + print("✓ Torque enabled") + + print("Reading all states...") + states = bus.sync_read_all_states() + print(f"✓ States: {states}") + + print("Reading 
position...") + positions = bus.sync_read("Present_Position") + print(f"✓ Position: {positions}") + + print("Testing MIT control batch...") + current_pos = states["joint_3"]["position"] + commands = {"joint_3": (10.0, 0.5, current_pos, 0.0, 0.0)} + bus._mit_control_batch(commands) + print("✓ MIT control batch sent") + + print("Disabling torque...") + bus.disable_torque() + print("✓ Torque disabled") + + print("Setting zero position...") + bus.set_zero_position() + print("✓ Zero position set") + + finally: + print("Disconnecting...") + bus.disconnect(disable_torque=True) + print("✓ Disconnected") + + +if __name__ == "__main__": + test_damiao_motor() From 0c0c171d3543fafe587e0ac1768dd033aaf12ebf Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Tue, 27 Jan 2026 13:33:45 +0100 Subject: [PATCH 064/111] Add robot images to docs (#2862) * Add robot images to docs * increase img size * remove img so100 --- docs/source/earthrover_mini_plus.mdx | 6 ++++++ docs/source/lekiwi.mdx | 6 ++++++ docs/source/so101.mdx | 13 +++++++++++++ 3 files changed, 25 insertions(+) diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index e3ffa6b32..d8083336a 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -1,5 +1,11 @@ # EarthRover Mini Plus +EarthRover Mini Plus + The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models. ## What You Need diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index 511521580..b339225d8 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -1,5 +1,11 @@ # LeKiwi +LeKiwi + In the steps below, we explain how to assemble the LeKiwi mobile robot. 
## Source the parts diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index cf882b373..7c9df588a 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -1,5 +1,18 @@ # SO-101 +
+ SO-101 + SO-101 +
+ In the steps below, we explain how to assemble our flagship robot, the SO-101. ## Source the parts From f6b1c39b785af0f2f78899f5de6e008f3295e594 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Tue, 27 Jan 2026 14:31:53 +0000 Subject: [PATCH 065/111] docs: update libero (#2857) * update libero docs * Update docs/source/libero.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jade Choghari --------- Signed-off-by: Jade Choghari Co-authored-by: Jade Choghari Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/source/libero.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx index 3617f3b25..def974531 100644 --- a/docs/source/libero.mdx +++ b/docs/source/libero.mdx @@ -42,6 +42,7 @@ lerobot-eval \ ``` - `--env.task` picks the suite (`libero_object`, `libero_spatial`, etc.). +- `--env.task_ids` picks task ids to run (`[0]`, `[1,2,3]`, etc.). Omit this flag (or set it to `null`) to run all tasks in the suite. - `--eval.batch_size` controls how many environments run in parallel. - `--eval.n_episodes` sets how many episodes to run in total. 
From 736b43f3cfb5db2450fa787a45f645e1309caa00 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 28 Jan 2026 13:31:27 +0100 Subject: [PATCH 066/111] Fix(aggregate.py) Aggregation of datasets when sub-datasets are already a result of a previous merge (#2861) * Fix aggeregation of datasets when subdatasets are already a result of a previous merge * docstring * respond to copilot review + add regression test * Remove unnecessary int conversion for indicies --- src/lerobot/datasets/aggregate.py | 100 ++++++++++++++++++++++++------ tests/datasets/test_aggregate.py | 89 ++++++++++++++++++++++++++ 2 files changed, 171 insertions(+), 18 deletions(-) diff --git a/src/lerobot/datasets/aggregate.py b/src/lerobot/datasets/aggregate.py index 94ffe602e..7020545d2 100644 --- a/src/lerobot/datasets/aggregate.py +++ b/src/lerobot/datasets/aggregate.py @@ -116,6 +116,9 @@ def update_meta_data( Adjusts all indices and timestamps to account for previously aggregated data and videos in the destination dataset. + For data file indices, uses the 'src_to_dst' mapping from aggregate_data() + to correctly map source file indices to their destination locations. + Args: df: DataFrame containing the metadata to be updated. dst_meta: Destination dataset metadata. 
@@ -129,8 +132,50 @@ def update_meta_data( df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"] df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"] - df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] - df["data/file_index"] = df["data/file_index"] + data_idx["file"] + + # Update data file indices using source-to-destination mapping + # This is critical for handling datasets that are already results of a merge + data_src_to_dst = data_idx.get("src_to_dst", {}) + if data_src_to_dst: + # Store original indices for lookup + df["_orig_data_chunk"] = df["data/chunk_index"].copy() + df["_orig_data_file"] = df["data/file_index"].copy() + + # Vectorized mapping from (src_chunk, src_file) to (dst_chunk, dst_file) + # This is much faster than per-row iteration for large metadata tables + mapping_index = pd.MultiIndex.from_tuples( + list(data_src_to_dst.keys()), + names=["chunk_index", "file_index"], + ) + mapping_values = list(data_src_to_dst.values()) + mapping_df = pd.DataFrame( + mapping_values, + index=mapping_index, + columns=["dst_chunk", "dst_file"], + ) + + # Construct a MultiIndex for each row based on original data indices + row_index = pd.MultiIndex.from_arrays( + [df["_orig_data_chunk"], df["_orig_data_file"]], + names=["chunk_index", "file_index"], + ) + + # Align mapping to rows; missing keys fall back to the default destination + reindexed = mapping_df.reindex(row_index) + reindexed[["dst_chunk", "dst_file"]] = reindexed[["dst_chunk", "dst_file"]].fillna( + {"dst_chunk": data_idx["chunk"], "dst_file": data_idx["file"]} + ) + + # Assign mapped destination indices back to the DataFrame + df["data/chunk_index"] = reindexed["dst_chunk"].to_numpy() + df["data/file_index"] = reindexed["dst_file"].to_numpy() + + # Clean up temporary columns + df = df.drop(columns=["_orig_data_chunk", "_orig_data_file"]) + else: + # Fallback to simple offset (backward compatibility for single-file 
sources) + df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"] + df["data/file_index"] = df["data/file_index"] + data_idx["file"] for key, video_idx in videos_idx.items(): # Store original video file indices before updating orig_chunk_col = f"videos/{key}/chunk_index" @@ -146,8 +191,7 @@ def update_meta_data( if src_to_dst: # Map each episode to its correct destination file and apply offset for idx in df.index: - # Convert to Python int to avoid numpy type mismatch in dict lookup - src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"])) + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) # Get destination chunk/file for this source file dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"])) @@ -163,8 +207,7 @@ def update_meta_data( df[orig_chunk_col] = video_idx["chunk"] df[orig_file_col] = video_idx["file"] for idx in df.index: - # Convert to Python int to avoid numpy type mismatch in dict lookup - src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"])) + src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"]) offset = src_to_offset.get(src_key, 0) df.at[idx, f"videos/{key}/from_timestamp"] += offset df.at[idx, f"videos/{key}/to_timestamp"] += offset @@ -262,6 +305,10 @@ def aggregate_datasets( meta_idx = aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx) + # Clear the src_to_dst mapping after processing each source dataset + # to avoid interference between different source datasets + data_idx.pop("src_to_dst", None) + dst_meta.info["total_episodes"] += src_meta.total_episodes dst_meta.info["total_frames"] += src_meta.total_frames @@ -312,10 +359,6 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu dst_file_durations = video_idx["dst_file_durations"] for src_chunk_idx, src_file_idx in unique_chunk_file_pairs: - # Convert to Python int to ensure consistent dict keys - src_chunk_idx = int(src_chunk_idx) - 
src_file_idx = int(src_file_idx) - src_path = src_meta.root / DEFAULT_VIDEO_PATH.format( video_key=key, chunk_index=src_chunk_idx, @@ -388,10 +431,16 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si Reads source data files, updates indices to match the aggregated dataset, and writes them to the destination with proper file rotation. + Tracks a `src_to_dst` mapping from source (chunk, file) to destination (chunk, file) + which is critical for correctly updating episode metadata when source datasets + have multiple data files (e.g., from a previous merge operation). + Args: src_meta: Source dataset metadata. dst_meta: Destination dataset metadata. data_idx: Dictionary tracking data chunk and file indices. + data_files_size_in_mb: Maximum size for data files in MB. + chunk_size: Maximum number of files per chunk. Returns: dict: Updated data_idx with current chunk and file indices. @@ -409,6 +458,10 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si # retrieve features schema for proper image typing in parquet hf_features = get_hf_features_from_features(dst_meta.features) if contains_images else None + # Track source to destination file mapping for metadata update + # This is critical for handling datasets that are already results of a merge + src_to_dst: dict[tuple[int, int], tuple[int, int]] = {} + for src_chunk_idx, src_file_idx in unique_chunk_file_ids: src_path = src_meta.root / DEFAULT_DATA_PATH.format( chunk_index=src_chunk_idx, file_index=src_file_idx @@ -421,7 +474,9 @@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si df = pd.read_parquet(src_path) df = update_data_df(df, src_meta, dst_meta) - data_idx = append_or_create_parquet_file( + # Write data and get the actual destination file it was written to + # This avoids duplicating the rotation logic here + data_idx, (dst_chunk, dst_file) = append_or_create_parquet_file( df, src_path, data_idx, @@ -433,6 +488,12 
@@ def aggregate_data(src_meta, dst_meta, data_idx, data_files_size_in_mb, chunk_si hf_features=hf_features, ) + # Record the mapping from source to actual destination + src_to_dst[(src_chunk_idx, src_file_idx)] = (dst_chunk, dst_file) + + # Add the mapping to data_idx for use in metadata update + data_idx["src_to_dst"] = src_to_dst + return data_idx @@ -473,7 +534,7 @@ def aggregate_metadata(src_meta, dst_meta, meta_idx, data_idx, videos_idx): videos_idx, ) - meta_idx = append_or_create_parquet_file( + meta_idx, _ = append_or_create_parquet_file( df, src_path, meta_idx, @@ -501,7 +562,7 @@ def append_or_create_parquet_file( contains_images: bool = False, aggr_root: Path = None, hf_features: datasets.Features | None = None, -): +) -> tuple[dict[str, int], tuple[int, int]]: """Appends data to an existing parquet file or creates a new one based on size constraints. Manages file rotation when size limits are exceeded to prevent individual files @@ -519,9 +580,11 @@ def append_or_create_parquet_file( hf_features: Optional HuggingFace Features schema for proper image typing. Returns: - dict: Updated index dictionary with current chunk and file indices. + tuple: (updated_idx, (dst_chunk, dst_file)) where updated_idx is the index dict + and (dst_chunk, dst_file) is the actual destination file the data was written to. 
""" - dst_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + dst_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) if not dst_path.exists(): dst_path.parent.mkdir(parents=True, exist_ok=True) @@ -529,14 +592,15 @@ def append_or_create_parquet_file( to_parquet_with_hf_images(df, dst_path, features=hf_features) else: df.to_parquet(dst_path) - return idx + return idx, (dst_chunk, dst_file) src_size = get_parquet_file_size_in_mb(src_path) dst_size = get_parquet_file_size_in_mb(dst_path) if dst_size + src_size >= max_mb: idx["chunk"], idx["file"] = update_chunk_file_indices(idx["chunk"], idx["file"], chunk_size) - new_path = aggr_root / default_path.format(chunk_index=idx["chunk"], file_index=idx["file"]) + dst_chunk, dst_file = idx["chunk"], idx["file"] + new_path = aggr_root / default_path.format(chunk_index=dst_chunk, file_index=dst_file) new_path.parent.mkdir(parents=True, exist_ok=True) final_df = df target_path = new_path @@ -555,7 +619,7 @@ def append_or_create_parquet_file( else: final_df.to_parquet(target_path) - return idx + return idx, (dst_chunk, dst_file) def finalize_aggregation(aggr_meta, all_metadata): diff --git a/tests/datasets/test_aggregate.py b/tests/datasets/test_aggregate.py index 031c29d60..3609bac24 100644 --- a/tests/datasets/test_aggregate.py +++ b/tests/datasets/test_aggregate.py @@ -525,3 +525,92 @@ def test_aggregate_image_datasets(tmp_path, lerobot_dataset_factory): assert img.shape[0] == 3, f"Image {image_key} should have 3 channels" assert_dataset_iteration_works(aggr_ds) + + +def test_aggregate_already_merged_dataset(tmp_path, lerobot_dataset_factory): + """Regression test for aggregating a dataset that is itself a result of a previous merge. 
+ + This test reproduces the bug where merging datasets with multiple parquet files + (e.g., from a previous merge with file rotation) would cause FileNotFoundError + because metadata file indices were incorrectly preserved instead of being mapped + to their actual destination files. + + The fix adds src_to_dst tracking in aggregate_data() to correctly map source + file indices to destination file indices. + """ + # Step 1: Create datasets A and B + ds_a = lerobot_dataset_factory( + root=tmp_path / "ds_a", + repo_id=f"{DUMMY_REPO_ID}_a", + total_episodes=4, + total_frames=200, + ) + ds_b = lerobot_dataset_factory( + root=tmp_path / "ds_b", + repo_id=f"{DUMMY_REPO_ID}_b", + total_episodes=4, + total_frames=200, + ) + + # Step 2: Merge A+B into AB with small file size to force multiple files + aggregate_datasets( + repo_ids=[ds_a.repo_id, ds_b.repo_id], + roots=[ds_a.root, ds_b.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_ab", + aggr_root=tmp_path / "ds_ab", + data_files_size_in_mb=0.01, # Force file rotation + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_ab") + ds_ab = LeRobotDataset(f"{DUMMY_REPO_ID}_ab", root=tmp_path / "ds_ab") + + # Verify AB has multiple data files (file rotation occurred) + ab_data_files = list((tmp_path / "ds_ab" / "data").rglob("*.parquet")) + assert len(ab_data_files) > 1, "First merge should create multiple parquet files" + + # Step 3: Create dataset C + ds_c = lerobot_dataset_factory( + root=tmp_path / "ds_c", + repo_id=f"{DUMMY_REPO_ID}_c", + total_episodes=2, + total_frames=100, + ) + + # Step 4: Merge AB+C into final - THIS IS WHERE THE BUG OCCURRED + aggregate_datasets( + repo_ids=[ds_ab.repo_id, ds_c.repo_id], + roots=[ds_ab.root, ds_c.root], + aggr_repo_id=f"{DUMMY_REPO_ID}_abc", + 
aggr_root=tmp_path / "ds_abc", + ) + + with ( + patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version, + patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download, + ): + mock_get_safe_version.return_value = "v3.0" + mock_snapshot_download.return_value = str(tmp_path / "ds_abc") + ds_abc = LeRobotDataset(f"{DUMMY_REPO_ID}_abc", root=tmp_path / "ds_abc") + + # Step 5: Verify all data files referenced in metadata actually exist + for ep_idx in range(ds_abc.num_episodes): + data_file_path = ds_abc.root / ds_abc.meta.get_data_file_path(ep_idx) + assert data_file_path.exists(), ( + f"Episode {ep_idx} references non-existent file: {data_file_path}\n" + "This indicates the src_to_dst mapping fix is not working correctly." + ) + + # Step 6: Verify we can iterate through the entire dataset without FileNotFoundError + expected_episodes = ds_a.num_episodes + ds_b.num_episodes + ds_c.num_episodes + expected_frames = ds_a.num_frames + ds_b.num_frames + ds_c.num_frames + + assert ds_abc.num_episodes == expected_episodes + assert ds_abc.num_frames == expected_frames + + # This would raise FileNotFoundError before the fix + assert_dataset_iteration_works(ds_abc) From bf337e716da18054e463003fa37f47df2aa9bfe3 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 28 Jan 2026 14:28:51 +0100 Subject: [PATCH 067/111] feat(robots): add OpenArm robot & teleoperator (#2795) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * feat(robot): add openarm leader Co-authored-by: Pepijn * feat(robot): add openarm follower Co-authored-by: Pepijn * refactor(robot): remove mechanical compensations and double arm assumption + rename * chore(robots): 
remove left arm references * refactor(teleop): multiple improvements to leader * refactor(teleop): multiple improvements to leader * feat(robots): add open arm to util CLI * chore(robot): add alias openarm * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damiao * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * fix(robots): open arm mirrored config for joint limits * chore(motors): update position_kd gain values * chore(robots): set to 0 if openarm is calibrated at connect time * chore(robots): remove macos in open arm as can doesn't support it * chore(robots): update for motor_type_str in Motor class * chore(robots): no default value for can port in open arms * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao motor bus, also add lerobot_setup_can.py and log if there is no response from a write command * precommit format * suppress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * don't import can for tests * remove comment * Add openarms docs * format * update purchase link * can to none if not available * add canfd option in bus * make handshake logic similar to lerobot-can * type hint * type check * add temp teleop test * remove script * mock class * ignore linter --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/_toctree.yml | 2 + docs/source/openarm.mdx | 258 +++++++++++++ pyproject.toml | 1 + src/lerobot/motors/damiao/damiao.py | 51 ++- 
src/lerobot/processor/hil_processor.py | 12 +- .../robots/openarm_follower/__init__.py | 20 + .../config_openarm_follower.py | 117 ++++++ .../openarm_follower/openarm_follower.py | 348 ++++++++++++++++++ src/lerobot/robots/utils.py | 4 + src/lerobot/scripts/lerobot_calibrate.py | 2 + .../scripts/lerobot_find_joint_limits.py | 2 + src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_replay.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../teleoperators/openarm_leader/__init__.py | 20 + .../openarm_leader/config_openarm_leader.py | 70 ++++ .../openarm_leader/openarm_leader.py | 225 +++++++++++ src/lerobot/teleoperators/utils.py | 14 +- 18 files changed, 1129 insertions(+), 22 deletions(-) create mode 100644 docs/source/openarm.mdx create mode 100644 src/lerobot/robots/openarm_follower/__init__.py create mode 100644 src/lerobot/robots/openarm_follower/config_openarm_follower.py create mode 100644 src/lerobot/robots/openarm_follower/openarm_follower.py create mode 100644 src/lerobot/teleoperators/openarm_leader/__init__.py create mode 100644 src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py create mode 100644 src/lerobot/teleoperators/openarm_leader/openarm_leader.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index f86dd11c7..eb97117af 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -101,6 +101,8 @@ title: Earth Rover Mini - local: omx title: OMX + - local: openarm + title: OpenArm title: "Robots" - sections: - local: phone_teleop diff --git a/docs/source/openarm.mdx b/docs/source/openarm.mdx new file mode 100644 index 000000000..661808749 --- /dev/null +++ b/docs/source/openarm.mdx @@ -0,0 +1,258 @@ +# OpenArm + +[OpenArm](https://openarm.dev) is an open-source 7DOF humanoid arm designed for physical AI research and deployment. 
+ +To get your OpenArm, assembled or DIY, and join the global community, browse verified and certified manufacturers worldwide at [openarm.dev](https://openarm.dev). + +## What's Unique? + +- **Human-Scale Design**: OpenArm is designed with human-like proportions, scaled for a person around 160-165cm tall. This provides an optimal balance between practical reach and manageable inertia for safe, responsive operation. + +- **Safety-First Architecture**: Built with QDD backdrivable motors and high compliance, OpenArm prioritizes safe human-robot interaction while maintaining practical payload capabilities (6.0kg peak / 4.1kg nominal) for real-world tasks. + +- **Built for Durability**: Critical structural components use aluminum and stainless steel construction, ensuring robust performance for repetitive data collection and continuous research use. + +- **Fully Accessible & Buildable**: Every component, from CNC parts and 3D-printed casings to electrical wiring is designed to be purchasable and buildable by individual researchers and labs, with complete fabrication data provided. + +- **Practical & Affordable**: At $6,500 USD for a complete bimanual system, OpenArm delivers research-grade capabilities at a fraction of traditional humanoid robot costs. + +## Platform Requirements + + + **Linux Only**: OpenArm currently only works on Linux. The CAN bus USB adapter + does not have macOS drivers and has not been tested on Windows. + + +## Safety Guide + +Before operating OpenArm, please read the [official safety guide](https://docs.openarm.dev/getting-started/safety-guide). 
Key points: + +- **Secure installation**: Fasten the arm to a flat, stable surface with screws or clamps +- **Safe distance**: Keep body parts and objects outside the range of motion during operation +- **Protective equipment**: Always wear safety goggles; use additional PPE as needed +- **Payload limits**: Do not exceed specified payload limits (6.0kg peak / 4.1kg nominal per arm) +- **Emergency stop**: Know the location and operation of the emergency stop device +- **Regular inspection**: Check for loose screws, damaged mechanical limits, unusual noises, and wiring damage + +## Hardware Setup + +Follow the official [OpenArm hardware documentation](https://docs.openarm.dev) for: + +- Bill of materials and sourcing +- 3D printing instructions +- Mechanical assembly +- Electrical wiring + +The hardware repositories are available at [github.com/enactic/openarm](https://github.com/enactic/openarm). + +## CAN Bus Setup + +OpenArm uses CAN bus communication with Damiao motors. Once you have the CAN bus USB adapter plugged into your Linux PC, follow the [Damiao Motors and CAN Bus guide](./damiao) to configure the interface. 
+ +Quick setup: + +```bash +# Setup CAN interfaces +lerobot-setup-can --mode=setup --interfaces=can0,can1 + +# Test motor communication +lerobot-setup-can --mode=test --interfaces=can0,can1 +``` + +## Install LeRobot 🤗 + +Follow our [Installation Guide](./installation), then install the Damiao motor support: + +```bash +pip install -e ".[damiao]" +``` + +## Usage + +### Follower Arm (Robot) + + + + +```bash +lerobot-calibrate \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_openarm_follower +``` + + + + +```python +from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig + +config = OpenArmFollowerConfig( + port="can0", + side="right", # or "left" for left arm + id="my_openarm_follower", +) + +follower = OpenArmFollower(config) +follower.connect() + +# Read current state +obs = follower.get_observation() +print(obs) + +# Send action (position in degrees) +action = { + "joint_1.pos": 0.0, + "joint_2.pos": 0.0, + "joint_3.pos": 0.0, + "joint_4.pos": 45.0, + "joint_5.pos": 0.0, + "joint_6.pos": 0.0, + "joint_7.pos": 0.0, + "gripper.pos": 0.0, +} +follower.send_action(action) + +follower.disconnect() +``` + + + + +### Leader Arm (Teleoperator) + +The leader arm is used for teleoperation - manually moving it to control the follower arm. 
+ + + + +```bash +lerobot-calibrate \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_openarm_leader +``` + + + + +```python +from lerobot.teleoperators.openarm_leader import OpenArmLeader, OpenArmLeaderConfig + +config = OpenArmLeaderConfig( + port="can1", + id="my_openarm_leader", + manual_control=True, # Disable torque for manual movement +) + +leader = OpenArmLeader(config) +leader.connect() + +# Read current position (as action to send to follower) +action = leader.get_action() +print(action) + +leader.disconnect() +``` + + + + +### Teleoperation + +To teleoperate OpenArm with leader-follower control: + +```bash +lerobot-teleoperate \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_follower \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_leader +``` + +### Recording Data + +To record a dataset during teleoperation: + +```bash +lerobot-record \ + --robot.type=openarm_follower \ + --robot.port=can0 \ + --robot.side=right \ + --robot.id=my_follower \ + --teleop.type=openarm_leader \ + --teleop.port=can1 \ + --teleop.id=my_leader \ + --repo-id=my_hf_username/my_openarm_dataset \ + --fps=30 \ + --num-episodes=10 +``` + +## Configuration Options + +### Follower Configuration + +| Parameter | Default | Description | +| --------------------- | --------- | ---------------------------------------------------------- | +| `port` | - | CAN interface (e.g., `can0`) | +| `side` | `None` | Arm side: `"left"`, `"right"`, or `None` for custom limits | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | +| `max_relative_target` | `None` | Safety limit for relative target positions | +| `position_kp` | Per-joint | Position control proportional gains | +| `position_kd` | Per-joint | Position control derivative gains | + +### Leader Configuration + 
+| Parameter | Default | Description | +| ------------------ | --------- | ----------------------------------- | +| `port` | - | CAN interface (e.g., `can1`) | +| `manual_control` | `True` | Disable torque for manual movement | +| `use_can_fd` | `True` | Enable CAN FD for higher data rates | +| `can_bitrate` | `1000000` | Nominal bitrate (1 Mbps) | +| `can_data_bitrate` | `5000000` | CAN FD data bitrate (5 Mbps) | + +## Motor Configuration + +OpenArm uses Damiao motors with the following default configuration: + +| Joint | Motor Type | Send ID | Recv ID | +| --------------------------- | ---------- | ------- | ------- | +| joint_1 (Shoulder pan) | DM8009 | 0x01 | 0x11 | +| joint_2 (Shoulder lift) | DM8009 | 0x02 | 0x12 | +| joint_3 (Shoulder rotation) | DM4340 | 0x03 | 0x13 | +| joint_4 (Elbow flex) | DM4340 | 0x04 | 0x14 | +| joint_5 (Wrist roll) | DM4310 | 0x05 | 0x15 | +| joint_6 (Wrist pitch) | DM4310 | 0x06 | 0x16 | +| joint_7 (Wrist rotation) | DM4310 | 0x07 | 0x17 | +| gripper | DM4310 | 0x08 | 0x18 | + +## Troubleshooting + +### No Response from Motors + +1. Check power supply connections +2. Verify CAN wiring (CAN-H, CAN-L, GND) +3. Run diagnostics: `lerobot-setup-can --mode=test --interfaces=can0` +4. 
See the [Damiao troubleshooting guide](./damiao#troubleshooting) for more details + +### CAN Interface Not Found + +Ensure the CAN interface is configured: + +```bash +ip link show can0 +``` + +## Resources + +- [OpenArm Website](https://openarm.dev) +- [OpenArm Documentation](https://docs.openarm.dev) +- [OpenArm GitHub](https://github.com/enactic/openarm) +- [Safety Guide](https://docs.openarm.dev/getting-started/safety-guide) +- [Damiao Motors and CAN Bus](./damiao) diff --git a/pyproject.toml b/pyproject.toml index 27126f855..ea2dfb4a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -105,6 +105,7 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] damiao = ["python-can>=4.2.0,<5.0.0"] # Robots +openarms = ["lerobot[damiao]"] gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"] hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index dd0213fc3..c79f8d17e 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -28,8 +28,11 @@ from lerobot.utils.import_utils import _can_available if TYPE_CHECKING or _can_available: import can else: - can.Message = object - can.interface = None + + class can: # noqa: N801 + Message = object + interface = None + import numpy as np @@ -206,11 +209,31 @@ class DamiaoMotorsBus(MotorsBusBase): Raises ConnectionError if any motor fails to respond. 
""" logger.info("Starting handshake with motors...") - missing_motors = [] + # Drain any pending messages + while self.canbus.recv(timeout=0.01): + pass + + missing_motors = [] for motor_name in self.motors: - msg = self._refresh_motor(motor_name) - if msg is None: + motor_id = self._get_motor_id(motor_name) + recv_id = self._get_motor_recv_id(motor_name) + + # Send enable command + data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, CAN_CMD_ENABLE] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) + self.canbus.send(msg) + + # Wait for response with longer timeout + response = None + start_time = time.time() + while time.time() - start_time < 0.1: + response = self.canbus.recv(timeout=0.1) + if response and response.arbitration_id == recv_id: + break + response = None + + if response is None: missing_motors.append(motor_name) else: self._process_response(motor_name, msg) @@ -259,7 +282,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_name = self._get_motor_name(motor) recv_id = self._get_motor_recv_id(motor) data = [0xFF] * 7 + [command_byte] - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) if msg := self._recv_motor_response(expected_recv_id=recv_id): self._process_response(motor_name, msg) @@ -317,7 +340,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_id = self._get_motor_id(motor) recv_id = self._get_motor_recv_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] - msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) return self._recv_motor_response(expected_recv_id=recv_id) @@ -439,7 +462,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_type = self._motor_types[motor_name] 
data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) recv_id = self._get_motor_recv_id(motor) @@ -472,7 +495,7 @@ class DamiaoMotorsBus(MotorsBusBase): motor_type = self._motor_types[motor_name] data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name @@ -637,10 +660,10 @@ class DamiaoMotorsBus(MotorsBusBase): for motor in motors: motor_id = self._get_motor_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] - msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False) + msg = can.Message( + arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd + ) self.canbus.send(msg) - # Small delay to reduce bus congestion if necessary, though removed in sync_read previously - # precise_sleep(PRECISE_SLEEP_SEC) # Collect responses expected_recv_ids = [self._get_motor_recv_id(m) for m in motors] @@ -676,7 +699,9 @@ class DamiaoMotorsBus(MotorsBusBase): kd = self._gains[motor]["kd"] data = self._encode_mit_packet(motor_type, kp, kd, float(value_degrees), 0.0, 0.0) - msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + msg = can.Message( + arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd + ) self.canbus.send(msg) precise_sleep(PRECISE_TIMEOUT_SEC) diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index f0dbac9c3..6d44ed8cb 100644 --- 
a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -18,16 +18,18 @@ import math import time from dataclasses import dataclass -from typing import Any, Protocol, TypeVar, runtime_checkable +from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable import numpy as np import torch import torchvision.transforms.functional as F # noqa: N812 from lerobot.configs.types import PipelineFeatureType, PolicyFeature -from lerobot.teleoperators.teleoperator import Teleoperator from lerobot.teleoperators.utils import TeleopEvents +if TYPE_CHECKING: + from lerobot.teleoperators.teleoperator import Teleoperator + from .core import EnvTransition, PolicyAction, TransitionKey from .pipeline import ( ComplementaryDataProcessorStep, @@ -69,10 +71,10 @@ class HasTeleopEvents(Protocol): # Type variable constrained to Teleoperator subclasses that also implement events -TeleopWithEvents = TypeVar("TeleopWithEvents", bound=Teleoperator) +TeleopWithEvents = TypeVar("TeleopWithEvents", bound="Teleoperator") -def _check_teleop_with_events(teleop: Teleoperator) -> None: +def _check_teleop_with_events(teleop: "Teleoperator") -> None: """ Runtime check that a teleoperator implements the `HasTeleopEvents` protocol. @@ -103,7 +105,7 @@ class AddTeleopActionAsComplimentaryDataStep(ComplementaryDataProcessorStep): teleop_device: The teleoperator instance to get the action from. """ - teleop_device: Teleoperator + teleop_device: "Teleoperator" def complementary_data(self, complementary_data: dict) -> dict: """ diff --git a/src/lerobot/robots/openarm_follower/__init__.py b/src/lerobot/robots/openarm_follower/__init__.py new file mode 100644 index 000000000..1eb0d9fc7 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_openarm_follower import OpenArmFollowerConfig +from .openarm_follower import OpenArmFollower + +__all__ = ["OpenArmFollower", "OpenArmFollowerConfig"] diff --git a/src/lerobot/robots/openarm_follower/config_openarm_follower.py b/src/lerobot/robots/openarm_follower/config_openarm_follower.py new file mode 100644 index 000000000..af95b6395 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/config_openarm_follower.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field + +from lerobot.cameras import CameraConfig + +from ..config import RobotConfig + +LEFT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { + "joint_1": (-75.0, 75.0), + "joint_2": (-90.0, 9.0), + "joint_3": (-85.0, 85.0), + "joint_4": (0.0, 135.0), + "joint_5": (-85.0, 85.0), + "joint_6": (-40.0, 40.0), + "joint_7": (-80.0, 80.0), + "gripper": (-65.0, 0.0), +} + +RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { + "joint_1": (-75.0, 75.0), + "joint_2": (-9.0, 90.0), + "joint_3": (-85.0, 85.0), + "joint_4": (0.0, 135.0), + "joint_5": (-85.0, 85.0), + "joint_6": (-40.0, 40.0), + "joint_7": (-80.0, 80.0), + "gripper": (-65.0, 0.0), +} + + +@RobotConfig.register_subclass("openarm_follower") +@dataclass +class OpenArmFollowerConfig(RobotConfig): + """Configuration for the OpenArms follower robot with Damiao motors.""" + + # CAN interfaces - one per arm + # arm CAN interface (e.g., "can1") + # Linux: "can0", "can1", etc. + port: str + + # side of the arm: "left" or "right". 
If "None" default values will be used + side: str | None = None + + # CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect) + can_interface: str = "socketcan" + + # CAN FD settings (OpenArms uses CAN FD by default) + use_can_fd: bool = True + can_bitrate: int = 1000000 # Nominal bitrate (1 Mbps) + can_data_bitrate: int = 5000000 # Data bitrate for CAN FD (5 Mbps) + + # Whether to disable torque when disconnecting + disable_torque_on_disconnect: bool = True + + # Safety limit for relative target positions + # Set to a positive scalar for all motors, or a dict mapping motor names to limits + max_relative_target: float | dict[str, float] | None = None + + # Camera configurations + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Motor configuration for OpenArms (7 DOF per arm) + # Maps motor names to (send_can_id, recv_can_id, motor_type) + # Based on: https://docs.openarm.dev/software/setup/configure-test + # OpenArms uses 4 types of motors: + # - DM8009 (DM-J8009P-2EC) for shoulders (high torque) + # - DM4340P and DM4340 for shoulder rotation and elbow + # - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper + motor_config: dict[str, tuple[int, int, str]] = field( + default_factory=lambda: { + "joint_1": (0x01, 0x11, "dm8009"), # J1 - Shoulder pan (DM8009) + "joint_2": (0x02, 0x12, "dm8009"), # J2 - Shoulder lift (DM8009) + "joint_3": (0x03, 0x13, "dm4340"), # J3 - Shoulder rotation (DM4340) + "joint_4": (0x04, 0x14, "dm4340"), # J4 - Elbow flex (DM4340) + "joint_5": (0x05, 0x15, "dm4310"), # J5 - Wrist roll (DM4310) + "joint_6": (0x06, 0x16, "dm4310"), # J6 - Wrist pitch (DM4310) + "joint_7": (0x07, 0x17, "dm4310"), # J7 - Wrist rotation (DM4310) + "gripper": (0x08, 0x18, "dm4310"), # J8 - Gripper (DM4310) + } + ) + + # MIT control parameters for position control (used in send_action) + # List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper] + position_kp: list[float] = field( + 
default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 25.0] + ) + position_kd: list[float] = field(default_factory=lambda: [5.0, 5.0, 3.0, 5.0, 0.3, 0.3, 0.3, 0.3]) + + # Values for joint limits. Can be overridden via CLI (for custom values) or by setting config.side to either 'left' or 'right'. + # If config.side is left set to None and no CLI values are passed, the default joint limit values are small for safety. + joint_limits: dict[str, tuple[float, float]] = field( + default_factory=lambda: { + "joint_1": (-5.0, 5.0), + "joint_2": (-5.0, 5.0), + "joint_3": (-5.0, 5.0), + "joint_4": (0.0, 5.0), + "joint_5": (-5.0, 5.0), + "joint_6": (-5.0, 5.0), + "joint_7": (-5.0, 5.0), + "gripper": (-5.0, 0.0), + } + ) diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py new file mode 100644 index 000000000..c221afd10 --- /dev/null +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -0,0 +1,348 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class OpenArmFollower(Robot):
    """
    OpenArms Follower Robot which uses CAN bus communication to control 7 DOF arm with a gripper.
    The arm uses Damiao motors in MIT control mode.
    """

    config_class = OpenArmFollowerConfig
    name = "openarm_follower"

    # TODO(Steven, Pepijn): Refactor writing
    # Position of each motor inside the per-joint gain lists
    # (`config.position_kp` / `config.position_kd`). Built once at class level
    # instead of on every `send_action` call.
    _GAIN_INDEX = {
        "joint_1": 0,
        "joint_2": 1,
        "joint_3": 2,
        "joint_4": 3,
        "joint_5": 4,
        "joint_6": 5,
        "joint_7": 6,
        "gripper": 7,
    }

    def __init__(self, config: OpenArmFollowerConfig):
        """Build the Damiao CAN bus, resolve the joint limits and create the cameras.

        Args:
            config: Follower configuration. When `config.side` is "left" or "right",
                `config.joint_limits` is overwritten in place with the matching defaults.

        Raises:
            ValueError: If `config.side` is neither "left", "right" nor None.
        """
        super().__init__(config)
        self.config = config

        # Arm motors
        motors: dict[str, Motor] = {}
        for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
            motor = Motor(
                send_id, motor_type_str, MotorNormMode.DEGREES
            )  # Always use degrees for Damiao motors
            motor.recv_id = recv_id
            motor.motor_type_str = motor_type_str
            motors[motor_name] = motor

        self.bus = DamiaoMotorsBus(
            port=self.config.port,
            motors=motors,
            calibration=self.calibration,
            can_interface=self.config.can_interface,
            use_can_fd=self.config.use_can_fd,
            bitrate=self.config.can_bitrate,
            data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
        )

        if config.side is not None:
            if config.side == "left":
                config.joint_limits = LEFT_DEFAULT_JOINTS_LIMITS
            elif config.side == "right":
                config.joint_limits = RIGHT_DEFAULT_JOINTS_LIMITS
            else:
                raise ValueError(
                    "config.side must be either 'left', 'right' (for default values) or 'None' (for CLI values)"
                )
        else:
            logger.info(
                "Set config.side to either 'left' or 'right' to use pre-configured values for joint limits."
            )
        logger.info(f"Values used for joint limits: {config.joint_limits}.")

        # Initialize cameras
        self.cameras = make_cameras_from_configs(config.cameras)

    @property
    def _motors_ft(self) -> dict[str, type]:
        """Motor features for observation and action spaces."""
        features: dict[str, type] = {}
        for motor in self.bus.motors:
            features[f"{motor}.pos"] = float
            features[f"{motor}.vel"] = float
            features[f"{motor}.torque"] = float
        return features

    @property
    def _cameras_ft(self) -> dict[str, tuple]:
        """Camera features for observation space."""
        return {
            cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
        }

    @cached_property
    def observation_features(self) -> dict[str, type | tuple]:
        """Combined observation features from motors and cameras."""
        return {**self._motors_ft, **self._cameras_ft}

    @cached_property
    def action_features(self) -> dict[str, type]:
        """Action features."""
        return self._motors_ft

    @property
    def is_connected(self) -> bool:
        """Check if robot is connected."""
        return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())

    def connect(self, calibrate: bool = True) -> None:
        """
        Connect to the robot and optionally calibrate.

        We assume that at connection time, the arms are in a safe rest position,
        and torque can be safely disabled to run calibration if needed.

        Args:
            calibrate: Run `calibrate()` when the bus reports it is not calibrated.

        Raises:
            DeviceAlreadyConnectedError: If the robot is already connected.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")

        # Connect to CAN bus
        logger.info(f"Connecting arm on {self.config.port}...")
        self.bus.connect()

        # Run calibration if needed
        if not self.is_calibrated and calibrate:
            logger.info(
                "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
            )
            self.calibrate()

        for cam in self.cameras.values():
            cam.connect()

        self.configure()

        if self.is_calibrated:
            self.bus.set_zero_position()

        self.bus.enable_torque()

        logger.info(f"{self} connected.")

    @property
    def is_calibrated(self) -> bool:
        """Check if robot is calibrated."""
        return self.bus.is_calibrated

    def calibrate(self) -> None:
        """
        Run calibration procedure for OpenArms robot.

        The calibration procedure:
        1. If a calibration is already loaded, offer to reuse it (write it to the motors and stop)
        2. Disable torque
        3. Ask user to position arms in hanging position with grippers closed
        4. Set this as zero position
        5. Assign a fixed default range of -90 to +90 degrees to every joint
        6. Write the calibration to the bus and save it to disk
        """
        if self.calibration:
            # Calibration file exists, ask user whether to use it or run new calibration
            user_input = input(
                f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
            )
            if user_input.strip().lower() != "c":
                logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
                self.bus.write_calibration(self.calibration)
                return

        logger.info(f"\nRunning calibration for {self}")
        self.bus.disable_torque()

        # Step 1: Set zero position
        input(
            "\nCalibration: Set Zero Position)\n"
            "Position the arm in the following configuration:\n"
            " - Arm hanging straight down\n"
            " - Gripper closed\n"
            "Press ENTER when ready..."
        )

        # Set current position as zero for all motors
        self.bus.set_zero_position()
        logger.info("Arm zero position set.")

        logger.info("Setting range: -90° to +90° for safety by default for all joints")
        for motor_name, motor in self.bus.motors.items():
            self.calibration[motor_name] = MotorCalibration(
                id=motor.id,
                drive_mode=0,
                homing_offset=0,
                range_min=-90,
                range_max=90,
            )

        self.bus.write_calibration(self.calibration)
        self._save_calibration()
        print(f"Calibration saved to {self.calibration_fpath}")

    def configure(self) -> None:
        """Configure motors with appropriate settings."""
        # TODO(Steven, Pepijn): Slightly different from what it is happening in the leader
        with self.bus.torque_disabled():
            self.bus.configure_motors()

    def setup_motors(self) -> None:
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError(
            "Motor ID configuration is typically done via manufacturer tools for CAN motors."
        )

    def get_observation(self) -> RobotObservation:
        """
        Get current observation from robot including position, velocity, and torque.

        Reads all motor states (pos/vel/torque) in one CAN refresh cycle
        instead of 3 separate reads.

        Returns:
            A dict with `<motor>.pos`, `<motor>.vel` and `<motor>.torque` for every motor
            (0.0 when a value is missing from the bus reply) plus one entry per camera.

        Raises:
            DeviceNotConnectedError: If the robot is not connected.
        """
        start = time.perf_counter()

        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        obs_dict: dict[str, Any] = {}

        states = self.bus.sync_read_all_states()

        for motor in self.bus.motors:
            state = states.get(motor, {})
            obs_dict[f"{motor}.pos"] = state.get("position", 0.0)
            obs_dict[f"{motor}.vel"] = state.get("velocity", 0.0)
            obs_dict[f"{motor}.torque"] = state.get("torque", 0.0)

        # Capture images from cameras.
        # A dedicated timer is used per camera so that `start` keeps measuring
        # the whole call for the summary log below.
        for cam_key, cam in self.cameras.items():
            cam_start = time.perf_counter()
            obs_dict[cam_key] = cam.async_read()
            cam_dt_ms = (time.perf_counter() - cam_start) * 1e3
            logger.debug(f"{self} read {cam_key}: {cam_dt_ms:.1f}ms")

        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug(f"{self} get_observation took: {dt_ms:.1f}ms")

        return obs_dict

    def _resolve_gain(
        self,
        motor_name: str,
        custom: dict[str, float] | None,
        default: list[float] | float,
    ) -> float:
        """Pick one MIT gain for `motor_name`: the custom override if given, else the config default."""
        if custom is not None and motor_name in custom:
            return custom[motor_name]
        if isinstance(default, list):
            # NOTE: motors missing from _GAIN_INDEX fall back to index 0 (joint_1's gain),
            # which is the behavior this code has always had.
            return default[self._GAIN_INDEX.get(motor_name, 0)]
        return default

    def send_action(
        self,
        action: RobotAction,
        custom_kp: dict[str, float] | None = None,
        custom_kd: dict[str, float] | None = None,
    ) -> RobotAction:
        """
        Send action command to robot.

        The action magnitude may be clipped based on safety limits.

        Args:
            action: Dictionary with motor positions (e.g., "joint_1.pos", "joint_2.pos")
            custom_kp: Optional custom kp gains per motor (e.g., {"joint_1": 120.0, "joint_2": 150.0})
            custom_kd: Optional custom kd gains per motor (e.g., {"joint_1": 1.5, "joint_2": 2.0})

        Returns:
            The action actually sent (potentially clipped)

        Raises:
            DeviceNotConnectedError: If the robot is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Only the "<motor>.pos" entries are used; other keys are ignored.
        goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}

        # Apply joint limit clipping to arm
        for motor_name, position in list(goal_pos.items()):
            if motor_name in self.config.joint_limits:
                min_limit, max_limit = self.config.joint_limits[motor_name]
                clipped_position = max(min_limit, min(max_limit, position))
                if clipped_position != position:
                    logger.debug(f"Clipped {motor_name} from {position:.2f}° to {clipped_position:.2f}°")
                goal_pos[motor_name] = clipped_position

        # Cap goal position when too far away from present position.
        # /!\ Slower fps expected due to reading from the follower.
        if self.config.max_relative_target is not None:
            present_pos = self.bus.sync_read("Present_Position")
            goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()}
            goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)

        # Use batch MIT control for arm (sends all commands, then collects responses)
        commands = {}
        for motor_name, position_degrees in goal_pos.items():
            # Use custom gains if provided, otherwise use config defaults
            kp = self._resolve_gain(motor_name, custom_kp, self.config.position_kp)
            kd = self._resolve_gain(motor_name, custom_kd, self.config.position_kd)
            commands[motor_name] = (kp, kd, position_degrees, 0.0, 0.0)

        self.bus._mit_control_batch(commands)

        return {f"{motor}.pos": val for motor, val in goal_pos.items()}

    def disconnect(self) -> None:
        """Disconnect from robot.

        Raises:
            DeviceNotConnectedError: If the robot is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Disconnect CAN bus
        self.bus.disconnect(self.config.disable_torque_on_disconnect)

        # Disconnect cameras
        for cam in self.cameras.values():
            cam.disconnect()

        logger.info(f"{self} disconnected.")
import Reachy2Robot return Reachy2Robot(config) + elif config.type == "openarm_follower": + from .openarm_follower import OpenArmFollower + + return OpenArmFollower(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index cbc7684d3..0f79e6aa2 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -42,6 +42,7 @@ from lerobot.robots import ( # noqa: F401 lekiwi, make_robot_from_config, omx_follower, + openarm_follower, so_follower, ) from lerobot.teleoperators import ( # noqa: F401 @@ -52,6 +53,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, so_leader, ) from lerobot.utils.import_utils import register_third_party_plugins diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index 20bbc8615..d928dc5cd 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -48,6 +48,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, so_follower, ) from lerobot.teleoperators import ( # noqa: F401 @@ -57,6 +58,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, so_leader, ) from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index f03776989..4d334f38f 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -104,6 +104,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, unitree_g1, @@ -116,6 +117,7 @@ from lerobot.teleoperators import ( # noqa: F401 
koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, reachy2_teleoperator, so_leader, ) diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 49c06d643..c3bc3d766 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -59,6 +59,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, unitree_g1, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 18d8863d6..a415dd600 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -76,6 +76,7 @@ from lerobot.robots import ( # noqa: F401 koch_follower, make_robot_from_config, omx_follower, + openarm_follower, reachy2, so_follower, ) @@ -89,6 +90,7 @@ from lerobot.teleoperators import ( # noqa: F401 koch_leader, make_teleoperator_from_config, omx_leader, + openarm_leader, reachy2_teleoperator, so_leader, ) diff --git a/src/lerobot/teleoperators/openarm_leader/__init__.py b/src/lerobot/teleoperators/openarm_leader/__init__.py new file mode 100644 index 000000000..1493317fe --- /dev/null +++ b/src/lerobot/teleoperators/openarm_leader/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
@TeleoperatorConfig.register_subclass("openarm_leader")
@dataclass
class OpenArmLeaderConfig(TeleoperatorConfig):
    """Configuration for the OpenArms leader/teleoperator with Damiao motors."""

    # CAN interface the arm is attached to (e.g., "can3").
    # Linux: "can0", "can1", etc.
    port: str

    # CAN interface type: "socketcan" (Linux), "slcan" (serial), or "auto" (auto-detect)
    can_interface: str = "socketcan"

    # CAN FD settings (OpenArms uses CAN FD by default)
    use_can_fd: bool = True
    can_bitrate: int = 1000000  # Nominal bitrate (1 Mbps)
    can_data_bitrate: int = 5000000  # Data bitrate for CAN FD (5 Mbps)

    # Motor configuration for OpenArms (7 DOF per arm)
    # Maps motor names to (send_can_id, recv_can_id, motor_type)
    # Based on: https://docs.openarm.dev/software/setup/configure-test
    # OpenArms uses 4 types of motors:
    # - DM8009 (DM-J8009P-2EC) for shoulders (high torque)
    # - DM4340P and DM4340 for shoulder rotation and elbow
    # - DM4310 (DM-J4310-2EC V1.1) for wrist and gripper
    motor_config: dict[str, tuple[int, int, str]] = field(
        default_factory=lambda: {
            "joint_1": (0x01, 0x11, "dm8009"),  # J1 - Shoulder pan (DM8009)
            "joint_2": (0x02, 0x12, "dm8009"),  # J2 - Shoulder lift (DM8009)
            "joint_3": (0x03, 0x13, "dm4340"),  # J3 - Shoulder rotation (DM4340)
            "joint_4": (0x04, 0x14, "dm4340"),  # J4 - Elbow flex (DM4340)
            "joint_5": (0x05, 0x15, "dm4310"),  # J5 - Wrist roll (DM4310)
            "joint_6": (0x06, 0x16, "dm4310"),  # J6 - Wrist pitch (DM4310)
            "joint_7": (0x07, 0x17, "dm4310"),  # J7 - Wrist rotation (DM4310)
            "gripper": (0x08, 0x18, "dm4310"),  # J8 - Gripper (DM4310)
        }
    )

    # Torque mode settings for manual control
    # When enabled, motors have torque disabled for manual movement
    manual_control: bool = True

    # TODO(Steven, Pepijn): Not used ... ?
    # MIT control parameters (used when manual_control=False for torque control)
    # List of 8 values: [joint_1, joint_2, joint_3, joint_4, joint_5, joint_6, joint_7, gripper]
    position_kp: list[float] = field(
        default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0]
    )
    position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2])


class OpenArmLeader(Teleoperator):
    """
    OpenArm Leader/Teleoperator Arm with Damiao motors.

    This teleoperator uses CAN bus communication to read positions from
    Damiao motors that are manually moved (torque disabled).
    """

    config_class = OpenArmLeaderConfig
    name = "openarm_leader"

    def __init__(self, config: OpenArmLeaderConfig):
        """Build the Damiao CAN bus described by `config` (no connection is opened here)."""
        super().__init__(config)
        self.config = config

        # Arm motors
        motors: dict[str, Motor] = {}
        for motor_name, (send_id, recv_id, motor_type_str) in config.motor_config.items():
            motor = Motor(
                send_id, motor_type_str, MotorNormMode.DEGREES
            )  # Always use degrees for Damiao motors
            motor.recv_id = recv_id
            motor.motor_type_str = motor_type_str
            motors[motor_name] = motor

        self.bus = DamiaoMotorsBus(
            port=self.config.port,
            motors=motors,
            calibration=self.calibration,
            can_interface=self.config.can_interface,
            use_can_fd=self.config.use_can_fd,
            bitrate=self.config.can_bitrate,
            data_bitrate=self.config.can_data_bitrate if self.config.use_can_fd else None,
        )

    @property
    def action_features(self) -> dict[str, type]:
        """Features produced by this teleoperator."""
        features: dict[str, type] = {}
        for motor in self.bus.motors:
            features[f"{motor}.pos"] = float
            features[f"{motor}.vel"] = float
            features[f"{motor}.torque"] = float
        return features

    @property
    def feedback_features(self) -> dict[str, type]:
        """Feedback features (not implemented for OpenArms)."""
        return {}

    @property
    def is_connected(self) -> bool:
        """Check if teleoperator is connected."""
        return self.bus.is_connected

    def connect(self, calibrate: bool = True) -> None:
        """
        Connect to the teleoperator.

        For manual control, we disable torque after connecting so the
        arm can be moved by hand.

        Args:
            calibrate: Run `calibrate()` when the bus reports it is not calibrated.

        Raises:
            DeviceAlreadyConnectedError: If the teleoperator is already connected.
        """
        if self.is_connected:
            raise DeviceAlreadyConnectedError(f"{self} already connected")

        # Connect to CAN bus
        logger.info(f"Connecting arm on {self.config.port}...")
        self.bus.connect()

        # Run calibration if needed
        if not self.is_calibrated and calibrate:
            logger.info(
                "Mismatch between calibration values in the motor and the calibration file or no calibration file found"
            )
            self.calibrate()

        self.configure()

        if self.is_calibrated:
            self.bus.set_zero_position()

        logger.info(f"{self} connected.")

    @property
    def is_calibrated(self) -> bool:
        """Check if teleoperator is calibrated."""
        return self.bus.is_calibrated

    def calibrate(self) -> None:
        """
        Run calibration procedure for OpenArms leader.

        The calibration procedure:
        1. If a calibration is already loaded, offer to reuse it (write it to the motors and stop)
        2. Disable torque (if not already disabled)
        3. Ask user to position arm in zero position (hanging with gripper closed)
        4. Set this as zero position
        5. Assign a fixed default range of -90 to +90 degrees to every joint
        6. Write the calibration to the bus and save it to disk
        """
        if self.calibration:
            # Calibration file exists, ask user whether to use it or run new calibration
            user_input = input(
                f"Press ENTER to use provided calibration file associated with the id {self.id}, or type 'c' and press ENTER to run calibration: "
            )
            if user_input.strip().lower() != "c":
                logger.info(f"Writing calibration file associated with the id {self.id} to the motors")
                self.bus.write_calibration(self.calibration)
                return

        logger.info(f"\nRunning calibration for {self}")
        self.bus.disable_torque()

        # Step 1: Set zero position
        input(
            "\nCalibration: Set Zero Position)\n"
            "Position the arm in the following configuration:\n"
            " - Arm hanging straight down\n"
            " - Gripper closed\n"
            "Press ENTER when ready..."
        )

        # Set current position as zero for all motors
        self.bus.set_zero_position()
        logger.info("Arm zero position set.")

        logger.info("Setting range: -90° to +90° by default for all joints")
        # TODO(Steven, Pepijn): Check if MotorCalibration is actually needed here given that we only use Degrees
        for motor_name, motor in self.bus.motors.items():
            self.calibration[motor_name] = MotorCalibration(
                id=motor.id,
                drive_mode=0,
                homing_offset=0,
                range_min=-90,
                range_max=90,
            )

        self.bus.write_calibration(self.calibration)
        self._save_calibration()
        print(f"Calibration saved to {self.calibration_fpath}")

    def configure(self) -> None:
        """
        Configure motors for teleoperation.

        With `config.manual_control` enabled, torque is disabled so the arm can be
        moved by hand; otherwise the bus applies its regular motor configuration.
        """
        # Plain statements: this method is declared to return None, so the result of
        # the bus calls is deliberately not propagated.
        if self.config.manual_control:
            self.bus.disable_torque()
        else:
            self.bus.configure_motors()

    def setup_motors(self) -> None:
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError(
            "Motor ID configuration is typically done via manufacturer tools for CAN motors."
        )

    def get_action(self) -> RobotAction:
        """
        Get current action from the leader arm.

        This is the main method for teleoperators - it reads the current state
        of the leader arm and returns it as an action that can be sent to a follower.

        Reads all motor states (pos/vel/torque) in one CAN refresh cycle.

        Returns:
            A dict with `<motor>.pos`, `<motor>.vel` and `<motor>.torque` for every motor.
            A value is None when the bus reply does not contain it.

        Raises:
            DeviceNotConnectedError: If the teleoperator is not connected.
        """
        start = time.perf_counter()
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        action_dict: dict[str, Any] = {}

        # Use sync_read_all_states to get pos/vel/torque in one go
        states = self.bus.sync_read_all_states()
        for motor in self.bus.motors:
            state = states.get(motor, {})
            # NOTE(review): missing readings stay None here, whereas the follower's
            # get_observation substitutes 0.0 — confirm which is intended before aligning.
            action_dict[f"{motor}.pos"] = state.get("position")
            action_dict[f"{motor}.vel"] = state.get("velocity")
            action_dict[f"{motor}.torque"] = state.get("torque")

        dt_ms = (time.perf_counter() - start) * 1e3
        logger.debug(f"{self} read state: {dt_ms:.1f}ms")

        return action_dict

    def send_feedback(self, feedback: dict[str, float]) -> None:
        """Not supported: always raises NotImplementedError."""
        raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.")

    def disconnect(self) -> None:
        """Disconnect from teleoperator.

        Raises:
            DeviceNotConnectedError: If the teleoperator is not connected.
        """
        if not self.is_connected:
            raise DeviceNotConnectedError(f"{self} is not connected.")

        # Disconnect CAN bus
        # For manual control, ensure torque is disabled before disconnecting
        self.bus.disconnect(disable_torque=self.config.manual_control)
        logger.info(f"{self} disconnected.")
from enum import Enum -from typing import cast +from typing import TYPE_CHECKING, cast from lerobot.utils.import_utils import make_device_from_device_class from .config import TeleoperatorConfig -from .teleoperator import Teleoperator + +if TYPE_CHECKING: + from .teleoperator import Teleoperator class TeleopEvents(Enum): @@ -31,7 +33,7 @@ class TeleopEvents(Enum): TERMINATE_EPISODE = "terminate_episode" -def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: +def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": # TODO(Steven): Consider just using the make_device_from_device_class for all types if config.type == "keyboard": from .keyboard import KeyboardTeleop @@ -81,8 +83,12 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: from .reachy2_teleoperator import Reachy2Teleoperator return Reachy2Teleoperator(config) + elif config.type == "openarm_leader": + from .openarm_leader import OpenArmLeader + + return OpenArmLeader(config) else: try: - return cast(Teleoperator, make_device_from_device_class(config)) + return cast("Teleoperator", make_device_from_device_class(config)) except Exception as e: raise ValueError(f"Error creating robot with config {config}: {e}") from e From 149628dfd5b3079a2c0f80832deac1da3e7bd287 Mon Sep 17 00:00:00 2001 From: Martino Russi <77496684+nepyope@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:17:38 +0100 Subject: [PATCH 068/111] add g1 teleoperation (#2791) * add gravity compensation * add g1 teleoperation --------- Co-authored-by: Michel Aractingi --- docs/source/unitree_g1.mdx | 100 +++- pyproject.toml | 6 +- .../robots/unitree_g1/config_unitree_g1.py | 3 + src/lerobot/robots/unitree_g1/g1_utils.py | 2 +- .../unitree_g1/robot_kinematic_processor.py | 313 ++++++++++++ src/lerobot/robots/unitree_g1/unitree_g1.py | 19 +- src/lerobot/scripts/lerobot_calibrate.py | 1 + src/lerobot/scripts/lerobot_record.py | 3 +- 
src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../teleoperators/unitree_g1/__init__.py | 21 + .../unitree_g1/config_unitree_g1.py | 37 ++ .../teleoperators/unitree_g1/exo_calib.py | 446 ++++++++++++++++++ .../teleoperators/unitree_g1/exo_ik.py | 353 ++++++++++++++ .../teleoperators/unitree_g1/exo_serial.py | 119 +++++ .../teleoperators/unitree_g1/unitree_g1.py | 157 ++++++ src/lerobot/teleoperators/utils.py | 4 + 16 files changed, 1581 insertions(+), 5 deletions(-) create mode 100644 src/lerobot/robots/unitree_g1/robot_kinematic_processor.py create mode 100644 src/lerobot/teleoperators/unitree_g1/__init__.py create mode 100644 src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_calib.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_ik.py create mode 100644 src/lerobot/teleoperators/unitree_g1/exo_serial.py create mode 100644 src/lerobot/teleoperators/unitree_g1/unitree_g1.py diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index e6bffdf1b..ea6bf54ad 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -188,7 +188,105 @@ Press `Ctrl+C` to stop the policy. ## Running in Simulation Mode (MuJoCo) -You can now test policies before unleashing them on the physical robot using MuJoCo. To do so simply set `is_simulation=True` in config. +You can test policies before deploying on the physical robot using MuJoCo simulation. Set `is_simulation=True` in config or pass `--robot.is_simulation=true` via CLI. 
+ +### Calibrate Exoskeleton Teleoperator + +```bash +lerobot-calibrate \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo +``` + +### Teleoperate in Simulation + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --fps=100 +``` + +### Record Dataset in Simulation + +```bash +python -m lerobot.scripts.lerobot_record \ + --robot.type=unitree_g1 \ + --robot.is_simulation=true \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --dataset.repo_id=your-username/dataset-name \ + --dataset.single_task="Test" \ + --dataset.num_episodes=2 \ + --dataset.episode_time_s=5 \ + --dataset.reset_time_s=5 \ + --dataset.push_to_hub=true +``` + +Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim) + +--- + +## Running on Real Robot + +Once the robot server is running on the G1 (see Part 3), you can teleoperate and record on the real robot. + +### Start the Camera Server + +On the robot, start the ZMQ image server: + +```bash +python src/lerobot/cameras/zmq/image_server.py +``` + +Keep this running in a separate terminal for camera streaming during recording. 
+ +### Teleoperate Real Robot + +```bash +lerobot-teleoperate \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --fps=100 +``` + +### Record Dataset on Real Robot + +```bash +python -m lerobot.scripts.lerobot_record \ + --robot.type=unitree_g1 \ + --robot.is_simulation=false \ + --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ + --teleop.type=unitree_g1 \ + --teleop.left_arm_config.port=/dev/ttyACM1 \ + --teleop.right_arm_config.port=/dev/ttyACM0 \ + --teleop.id=exo \ + --dataset.repo_id=your-username/dataset-name \ + --dataset.single_task="Test" \ + --dataset.num_episodes=2 \ + --dataset.episode_time_s=5 \ + --dataset.reset_time_s=5 \ + --dataset.push_to_hub=true +``` + +**Note**: Update `server_address` to match your robot's camera server IP. 
+ +Example real robot dataset: [nepyope/teleop_test_real](https://huggingface.co/datasets/nepyope/teleop_test_real) + +--- ## Additional Resources diff --git a/pyproject.toml b/pyproject.toml index ea2dfb4a2..210d70b6b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,11 @@ hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"] lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"] unitree_g1 = [ "pyzmq>=26.2.1,<28.0.0", - "onnxruntime>=1.16.0,<2.0.0" + "onnxruntime>=1.16.0,<2.0.0", + "pin>=3.0.0,<4.0.0", + "meshcat>=0.3.0,<0.4.0", + "matplotlib>=3.9.0,<4.0.0", + "casadi>=3.6.0,<4.0.0", ] reachy2 = ["reachy2_sdk>=1.0.15,<1.1.0"] kinematics = ["lerobot[placo-dep]"] diff --git a/src/lerobot/robots/unitree_g1/config_unitree_g1.py b/src/lerobot/robots/unitree_g1/config_unitree_g1.py index 0b163019d..1b81214a6 100644 --- a/src/lerobot/robots/unitree_g1/config_unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/config_unitree_g1.py @@ -65,3 +65,6 @@ class UnitreeG1Config(RobotConfig): # Cameras (ZMQ-based remote cameras) cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Compensates for gravity on the unitree's arms using the arm ik solver + gravity_compensation: bool = False diff --git a/src/lerobot/robots/unitree_g1/g1_utils.py b/src/lerobot/robots/unitree_g1/g1_utils.py index 3c41ee985..4e37bdcef 100644 --- a/src/lerobot/robots/unitree_g1/g1_utils.py +++ b/src/lerobot/robots/unitree_g1/g1_utils.py @@ -18,7 +18,7 @@ from enum import IntEnum # ruff: noqa: N801, N815 -NUM_MOTORS = 35 +NUM_MOTORS = 29 class G1_29_JointArmIndex(IntEnum): diff --git a/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py b/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py new file mode 100644 index 000000000..d086a9986 --- /dev/null +++ b/src/lerobot/robots/unitree_g1/robot_kinematic_processor.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. 
class WeightedMovingFilter:
    """Weighted moving average over the most recent fixed-size sample vectors.

    The window length equals ``len(weights)``. ``weights[0]`` is applied to the
    newest sample and ``weights[-1]`` to the oldest one in the window. Until the
    window is full, the filter output is simply the latest sample.
    """

    def __init__(self, weights, data_size=14):
        """Create the filter.

        Args:
            weights: Non-empty 1-D sequence of weights, newest sample first.
            data_size: Number of values in each sample vector.

        Raises:
            ValueError: If ``weights`` is empty or not one-dimensional.
        """
        self._weights = np.array(weights, dtype=float)
        if self._weights.ndim != 1 or self._weights.size == 0:
            raise ValueError("weights must be a non-empty 1-D sequence")
        self._window_size = len(self._weights)
        self._data_size = data_size
        self._filtered_data = np.zeros(self._data_size)
        self._data_queue = []

    def _apply_filter(self):
        """Return the filter output for the current window contents."""
        if len(self._data_queue) < self._window_size:
            # Warm-up: not enough history yet, pass the latest sample through.
            return self._data_queue[-1]

        # Rows are ordered oldest -> newest, so the weights are reversed to put
        # weights[0] on the newest row. This equals the last element of
        # np.convolve(column, weights, mode="valid") for every column at once.
        return self._weights[::-1] @ np.array(self._data_queue)

    def add_data(self, new_data):
        """Push one sample into the window and refresh ``filtered_data``.

        A sample equal to the previously stored one is ignored.

        Args:
            new_data: Sequence of ``data_size`` numeric values.

        Raises:
            ValueError: If ``new_data`` does not hold exactly ``data_size`` values.
        """
        # Copy the sample: storing the caller's object would let later in-place
        # edits rewrite the history and make every reused buffer look like a
        # duplicate of itself.
        sample = np.array(new_data, dtype=float)
        if sample.shape != (self._data_size,):
            raise ValueError(
                f"expected a sample of {self._data_size} values, got shape {sample.shape}"
            )

        if self._data_queue and np.array_equal(sample, self._data_queue[-1]):  # skip duplicate data
            return

        if len(self._data_queue) >= self._window_size:
            self._data_queue.pop(0)

        self._data_queue.append(sample)
        self._filtered_data = self._apply_filter()

    @property
    def filtered_data(self):
        """Latest filter output (zeros until the first sample is added)."""
        return self._filtered_data
import casadi as cpin + + self._pin = pin + np.set_printoptions(precision=5, suppress=True, linewidth=200) + + self.unit_test = unit_test + + self.repo_path = snapshot_download("lerobot/unitree-g1-mujoco") + urdf_path = os.path.join(self.repo_path, "assets", "g1_body29_hand14.urdf") + mesh_dir = os.path.join(self.repo_path, "assets") + + self.robot = self._pin.RobotWrapper.BuildFromURDF(urdf_path, mesh_dir) + + self.mixed_jointsToLockIDs = [ + "left_hip_pitch_joint", + "left_hip_roll_joint", + "left_hip_yaw_joint", + "left_knee_joint", + "left_ankle_pitch_joint", + "left_ankle_roll_joint", + "right_hip_pitch_joint", + "right_hip_roll_joint", + "right_hip_yaw_joint", + "right_knee_joint", + "right_ankle_pitch_joint", + "right_ankle_roll_joint", + "waist_yaw_joint", + "waist_roll_joint", + "waist_pitch_joint", + "left_hand_thumb_0_joint", + "left_hand_thumb_1_joint", + "left_hand_thumb_2_joint", + "left_hand_middle_0_joint", + "left_hand_middle_1_joint", + "left_hand_index_0_joint", + "left_hand_index_1_joint", + "right_hand_thumb_0_joint", + "right_hand_thumb_1_joint", + "right_hand_thumb_2_joint", + "right_hand_index_0_joint", + "right_hand_index_1_joint", + "right_hand_middle_0_joint", + "right_hand_middle_1_joint", + ] + + self.reduced_robot = self.robot.buildReducedRobot( + list_of_joints_to_lock=self.mixed_jointsToLockIDs, + reference_configuration=np.array([0.0] * self.robot.model.nq), + ) + + # Arm joint names in G1 motor order (G1_29_JointArmIndex) + self._arm_joint_names_g1 = [ + "left_shoulder_pitch_joint", + "left_shoulder_roll_joint", + "left_shoulder_yaw_joint", + "left_elbow_joint", + "left_wrist_roll_joint", + "left_wrist_pitch_joint", + "left_wrist_yaw_joint", + "right_shoulder_pitch_joint", + "right_shoulder_roll_joint", + "right_shoulder_yaw_joint", + "right_elbow_joint", + "right_wrist_roll_joint", + "right_wrist_pitch_joint", + "right_wrist_yaw_joint", + ] + # Pinocchio uses its own joint order in q; build index mapping. 
+ self._arm_joint_names_pin = sorted( + self._arm_joint_names_g1, + key=lambda name: self.reduced_robot.model.idx_qs[self.reduced_robot.model.getJointId(name)], + ) + logger.info(f"Pinocchio arm joint order: {self._arm_joint_names_pin}") + self._arm_reorder_g1_to_pin = [ + self._arm_joint_names_g1.index(name) for name in self._arm_joint_names_pin + ] + # Inverse mapping to return tau in G1 motor order. + self._arm_reorder_pin_to_g1 = np.argsort(self._arm_reorder_g1_to_pin) + + self.reduced_robot.model.addFrame( + self._pin.Frame( + "L_ee", + self.reduced_robot.model.getJointId("left_wrist_yaw_joint"), + self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + self._pin.FrameType.OP_FRAME, + ) + ) + + self.reduced_robot.model.addFrame( + self._pin.Frame( + "R_ee", + self.reduced_robot.model.getJointId("right_wrist_yaw_joint"), + self._pin.SE3(np.eye(3), np.array([0.05, 0, 0]).T), + self._pin.FrameType.OP_FRAME, + ) + ) + + # Creating Casadi models and data for symbolic computing + self.cmodel = cpin.Model(self.reduced_robot.model) + self.cdata = self.cmodel.createData() + + # Creating symbolic variables + self.cq = casadi.SX.sym("q", self.reduced_robot.model.nq, 1) + self.cTf_l = casadi.SX.sym("tf_l", 4, 4) + self.cTf_r = casadi.SX.sym("tf_r", 4, 4) + cpin.framesForwardKinematics(self.cmodel, self.cdata, self.cq) + + # Get the hand joint ID and define the error function + self.L_hand_id = self.reduced_robot.model.getFrameId("L_ee") + self.R_hand_id = self.reduced_robot.model.getFrameId("R_ee") + + self.translational_error = casadi.Function( + "translational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + self.cdata.oMf[self.L_hand_id].translation - self.cTf_l[:3, 3], + self.cdata.oMf[self.R_hand_id].translation - self.cTf_r[:3, 3], + ) + ], + ) + self.rotational_error = casadi.Function( + "rotational_error", + [self.cq, self.cTf_l, self.cTf_r], + [ + casadi.vertcat( + cpin.log3(self.cdata.oMf[self.L_hand_id].rotation @ self.cTf_l[:3, :3].T), + 
cpin.log3(self.cdata.oMf[self.R_hand_id].rotation @ self.cTf_r[:3, :3].T), + ) + ], + ) + + # Defining the optimization problem + self.opti = casadi.Opti() + self.var_q = self.opti.variable(self.reduced_robot.model.nq) + self.var_q_last = self.opti.parameter(self.reduced_robot.model.nq) # for smooth + self.param_tf_l = self.opti.parameter(4, 4) + self.param_tf_r = self.opti.parameter(4, 4) + self.translational_cost = casadi.sumsqr( + self.translational_error(self.var_q, self.param_tf_l, self.param_tf_r) + ) + self.rotation_cost = casadi.sumsqr( + self.rotational_error(self.var_q, self.param_tf_l, self.param_tf_r) + ) + self.regularization_cost = casadi.sumsqr(self.var_q) + self.smooth_cost = casadi.sumsqr(self.var_q - self.var_q_last) + + # Setting optimization constraints and goals + self.opti.subject_to( + self.opti.bounded( + self.reduced_robot.model.lowerPositionLimit, + self.var_q, + self.reduced_robot.model.upperPositionLimit, + ) + ) + self.opti.minimize( + 50 * self.translational_cost + + self.rotation_cost + + 0.02 * self.regularization_cost + + 0.1 * self.smooth_cost + ) + + opts = { + "ipopt": {"print_level": 0, "max_iter": 50, "tol": 1e-6}, + "print_time": False, # print or not + "calc_lam_p": False, # https://github.com/casadi/casadi/wiki/FAQ:-Why-am-I-getting-%22NaN-detected%22in-my-optimization%3F + } + self.opti.solver("ipopt", opts) + + self.init_data = np.zeros(self.reduced_robot.model.nq) + self.smooth_filter = WeightedMovingFilter(np.array([0.4, 0.3, 0.2, 0.1]), 14) + + def solve_ik(self, left_wrist, right_wrist, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + if current_lr_arm_motor_q is not None: + self.init_data = current_lr_arm_motor_q + self.opti.set_initial(self.var_q, self.init_data) + + self.opti.set_value(self.param_tf_l, left_wrist) + self.opti.set_value(self.param_tf_r, right_wrist) + self.opti.set_value(self.var_q_last, self.init_data) # for smooth + + try: + self.opti.solve() + + sol_q = self.opti.value(self.var_q) + 
self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + if current_lr_arm_motor_dq is not None: + v = current_lr_arm_motor_dq * 0.0 + else: + v = (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + sol_tauff = self._pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + sol_q, + v, + np.zeros(self.reduced_robot.model.nv), + ) + + return sol_q, sol_tauff + + except Exception as e: + logger.error(f"ERROR in convergence, plotting debug info.{e}") + + sol_q = self.opti.debug.value(self.var_q) + self.smooth_filter.add_data(sol_q) + sol_q = self.smooth_filter.filtered_data + + if current_lr_arm_motor_dq is not None: + v = current_lr_arm_motor_dq * 0.0 + else: + v = (sol_q - self.init_data) * 0.0 + + self.init_data = sol_q + + logger.error( + f"sol_q:{sol_q} \nmotorstate: \n{current_lr_arm_motor_q} \nleft_pose: \n{left_wrist} \nright_pose: \n{right_wrist}" + ) + + return current_lr_arm_motor_q, np.zeros(self.reduced_robot.model.nv) + + def solve_tau(self, current_lr_arm_motor_q=None, current_lr_arm_motor_dq=None): + try: + q_g1 = np.array(current_lr_arm_motor_q, dtype=float) + if q_g1.shape[0] != len(self._arm_joint_names_g1): + raise ValueError(f"Expected {len(self._arm_joint_names_g1)} arm joints, got {q_g1.shape[0]}") + q_pin = q_g1[self._arm_reorder_g1_to_pin] + sol_tauff = self._pin.rnea( + self.reduced_robot.model, + self.reduced_robot.data, + q_pin, + np.zeros(self.reduced_robot.model.nv), + np.zeros(self.reduced_robot.model.nv), + ) + return sol_tauff[self._arm_reorder_pin_to_g1] + + except Exception as e: + logger.error(f"ERROR in convergence, plotting debug info.{e}") + return np.zeros(self.reduced_robot.model.nv) diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index fa6e0da85..01b4f330e 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -27,7 +27,8 @@ import numpy as np from lerobot.cameras.utils 
import make_cameras_from_configs from lerobot.envs.factory import make_env from lerobot.processor import RobotAction, RobotObservation -from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex, G1_29_JointIndex +from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK from ..robot import Robot from .config_unitree_g1 import UnitreeG1Config @@ -127,6 +128,8 @@ class UnitreeG1(Robot): self.subscribe_thread = None self.remote_controller = self.RemoteController() + self.arm_ik = G1_29_ArmIK() + def _subscribe_motor_state(self): # polls robot state @ 250Hz while not self._shutdown_event.is_set(): start_time = time.time() @@ -361,6 +364,20 @@ class UnitreeG1(Robot): self.msg.motor_cmd[motor.value].kd = self.kd[motor.value] self.msg.motor_cmd[motor.value].tau = 0 + if self.config.gravity_compensation: + # Build action_np from motor commands (arm joints are indices 15-28, local indices 0-13) + action_np = np.zeros(14) + arm_start_idx = G1_29_JointArmIndex.kLeftShoulderPitch.value # 15 + for joint in G1_29_JointArmIndex: + local_idx = joint.value - arm_start_idx + action_np[local_idx] = self.msg.motor_cmd[joint.value].q + tau = self.arm_ik.solve_tau(action_np) + + # Apply tau back to motor commands + for joint in G1_29_JointArmIndex: + local_idx = joint.value - arm_start_idx + self.msg.motor_cmd[joint.value].tau = tau[local_idx] + self.msg.crc = self.crc.Crc(self.msg) self.lowcmd_publisher.Write(self.msg) return action diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 0f79e6aa2..2fa1b2a03 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -55,6 +55,7 @@ from lerobot.teleoperators import ( # noqa: F401 omx_leader, openarm_leader, so_leader, + unitree_g1, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.utils import init_logging diff --git 
a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 4d334f38f..d621189e8 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -107,7 +107,7 @@ from lerobot.robots import ( # noqa: F401 openarm_follower, reachy2, so_follower, - unitree_g1, + unitree_g1 as unitree_g1_robot, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -120,6 +120,7 @@ from lerobot.teleoperators import ( # noqa: F401 openarm_leader, reachy2_teleoperator, so_leader, + unitree_g1, ) from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop from lerobot.utils.constants import ACTION, OBS_STR diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index a415dd600..958bd00ef 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -79,6 +79,7 @@ from lerobot.robots import ( # noqa: F401 openarm_follower, reachy2, so_follower, + unitree_g1 as unitree_g1_robot, ) from lerobot.teleoperators import ( # noqa: F401 Teleoperator, @@ -93,6 +94,7 @@ from lerobot.teleoperators import ( # noqa: F401 openarm_leader, reachy2_teleoperator, so_leader, + unitree_g1, ) from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.robot_utils import precise_sleep diff --git a/src/lerobot/teleoperators/unitree_g1/__init__.py b/src/lerobot/teleoperators/unitree_g1/__init__.py new file mode 100644 index 000000000..45955a0e2 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config_unitree_g1 import ExoskeletonArmPortConfig, UnitreeG1TeleoperatorConfig +from .exo_calib import ExoskeletonCalibration, ExoskeletonJointCalibration +from .exo_ik import ExoskeletonIKHelper +from .exo_serial import ExoskeletonArm +from .unitree_g1 import UnitreeG1Teleoperator diff --git a/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py b/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py new file mode 100644 index 000000000..66c4e7f31 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/config_unitree_g1.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass, field + +from ..config import TeleoperatorConfig + + +@dataclass +class ExoskeletonArmPortConfig: + """Serial port configuration for individual exoskeleton arm.""" + + port: str = "" + baud_rate: int = 115200 + + +@TeleoperatorConfig.register_subclass("unitree_g1") +@dataclass +class UnitreeG1TeleoperatorConfig(TeleoperatorConfig): + left_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig) + right_arm_config: ExoskeletonArmPortConfig = field(default_factory=ExoskeletonArmPortConfig) + + # Frozen joints (comma-separated joint names that won't be moved by IK) + frozen_joints: str = "" diff --git a/src/lerobot/teleoperators/unitree_g1/exo_calib.py b/src/lerobot/teleoperators/unitree_g1/exo_calib.py new file mode 100644 index 000000000..2927a1b55 --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_calib.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This module handles calibration of hall effect sensors used in the exoskeleton. +Each joint has a pair of ADC channels outputting sin and cos values that trace an ellipse +as the joint rotates due to imprecision in magnet/sensor placement. We fit this ellipse to a unit circle, +and calculate arctan2 of the unit circle to get the joint angle. 
+We then store the ellipse parameters and the zero offset for each joint to be used at runtime. +""" + +import json +import logging +import time +from collections import deque +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import serial + +logger = logging.getLogger(__name__) + + +# exoskeleton joint names -> ADC channel pairs. TODO: add wrist pitch and wrist yaw +JOINTS = { + "shoulder_pitch": (0, 1), + "shoulder_yaw": (2, 3), + "shoulder_roll": (4, 5), + "elbow_flex": (6, 7), + "wrist_roll": (14, 15), +} + + +@dataclass +class ExoskeletonJointCalibration: + name: str # joint name + center_fit: list[float] # center of the ellipse + T: list[list[float]] # 2x2 transformation matrix + zero_offset: float = 0.0 # angle at neutral pose + + +@dataclass +class ExoskeletonCalibration: + """Full calibration data for an exoskeleton arm.""" + + version: int = 2 + side: str = "" + adc_max: int = 2**12 - 1 + joints: list[ExoskeletonJointCalibration] = field(default_factory=list) + + def to_dict(self) -> dict: + return { + "version": self.version, + "side": self.side, + "adc_max": self.adc_max, + "joints": [ + { + "name": j.name, + "center_fit": j.center_fit, + "T": j.T, + "zero_offset": j.zero_offset, + } + for j in self.joints + ], + } + + @classmethod + def from_dict(cls, data: dict) -> "ExoskeletonCalibration": + joints = [ + ExoskeletonJointCalibration( + name=j["name"], + center_fit=j["center_fit"], + T=j["T"], + zero_offset=j.get("zero_offset", 0.0), + ) + for j in data.get("joints", []) + ] + return cls( + version=data.get("version", 2), + side=data.get("side", ""), + adc_max=data.get("adc_max", 2**12 - 1), + joints=joints, + ) + + +@dataclass(frozen=True) +class CalibParams: + fit_every: float = 0.15 + min_fit_points: int = 60 + fit_window: int = 900 + max_fit_points: int = 300 + trim_low: float = 0.05 + trim_high: float = 0.95 + median_window: int = 5 + history: int = 3500 + draw_hz: float = 120.0 + sample_count: int = 50 + + 
+def normalize_angle(angle: float) -> float: + while angle > np.pi: + angle -= 2 * np.pi + while angle < -np.pi: + angle += 2 * np.pi + return angle + + +def joint_z_and_angle(raw16: list[int], j: ExoskeletonJointCalibration) -> tuple[np.ndarray, float]: + """ + Applies calibration to each joint: raw → centered → ellipse-to-circle → angle. + """ + pair = JOINTS[j.name] + s, c = raw16[pair[0]], raw16[pair[1]] # get sin and cos + p = np.array([float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2]) # center the raw values + z = np.asarray(j.T) @ ( + p - np.asarray(j.center_fit) + ) # center the ellipse and invert the transformation matrix to get unit circle coords + ang = float(np.arctan2(z[1], z[0])) - j.zero_offset # calculate the anvgle and apply the zero offset + return z, normalize_angle(-ang) # ensure range is [-pi, pi] + + +def exo_raw_to_angles(raw16: list[int], calib: ExoskeletonCalibration) -> dict[str, float]: + """Convert raw sensor readings to joint angles using calibration.""" + return {j.name: joint_z_and_angle(raw16, j)[1] for j in calib.joints} + + +def run_exo_calibration( + ser: serial.Serial, + side: str, + save_path: Path, + params: CalibParams | None = None, +) -> ExoskeletonCalibration: + """ + Run interactive calibration for an exoskeleton arm. + """ + try: + import cv2 + import matplotlib.pyplot as plt + except ImportError as e: + raise ImportError( + "Calibration requires matplotlib and opencv-python. 
" + "Install with: pip install matplotlib opencv-python" + ) from e + + from .exo_serial import read_raw_from_serial + + params = params or CalibParams() + joint_list = list(JOINTS.items()) # Convert dict to list for indexing + logger.info(f"Starting calibration for {side} exoskeleton arm") + + def running_median(win: deque) -> float: + return float(np.median(np.fromiter(win, dtype=float))) + + def read_joint_point(raw16: list[int], pair: tuple[int, int]): + s, c = raw16[pair[0]], raw16[pair[1]] + return float(c) - (2**12 - 1) / 2, float(s) - (2**12 - 1) / 2, float(s), float(c) + + def select_fit_subset(xs, ys): + """Select and filter points for ellipse fitting. Trims outliers by radius and downsamples.""" + n = min(params.fit_window, len(xs)) + if n <= 0: + return None, None + x = np.asarray(list(xs)[-n:], dtype=float) # most recent n samples + y = np.asarray(list(ys)[-n:], dtype=float) + r = np.sqrt(x * x + y * y) # radius from origin + if len(r) >= 20: + lo, hi = np.quantile(r, params.trim_low), np.quantile(r, params.trim_high) # outlier bounds + keep = (r >= lo) & (r <= hi) + x, y = x[keep], y[keep] # remove outliers + if len(x) > params.max_fit_points: + idx = np.linspace(0, len(x) - 1, params.max_fit_points).astype(int) # downsample evenly + x, y = x[idx], y[idx] + return x, y + + def fit_ellipse_opencv(x, y): + """Fit ellipse to (x,y) points using OpenCV. 
Returns center, axes, rotation matrix, and outline.""" + x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float) + if len(x) < 5: + return None + pts = np.stack([x, y], axis=1).astype(np.float32).reshape(-1, 1, 2) + try: + (xc, yc), (w, h), angle_deg = cv2.fitEllipse(pts) # returns center, axes, rotation in degrees + except cv2.error: + return None + a, b = float(w) * 0.5, float(h) * 0.5 # get ellipse major and minor semi-axes + phi = np.deg2rad(float(angle_deg)) # to rad + if b > a: # ensure major axis is a + a, b = b, a + phi += np.pi / 2.0 + if not np.isfinite(a) or not np.isfinite(b) or a <= 1e-6 or b <= 1e-6: + return None + cp, sp = float(np.cos(phi)), float(np.sin(phi)) # + rot = np.array([[cp, -sp], [sp, cp]], dtype=float) # 2x2 rotation matrix + center = np.array([float(xc), float(yc)], dtype=float) # offset vector + tt = np.linspace(0, 2 * np.pi, 360) + outline = (rot @ np.stack([a * np.cos(tt), b * np.sin(tt)])).T + center # for viz + return {"center": center, "a": a, "b": b, "R": rot, "ex": outline[:, 0], "ey": outline[:, 1]} + + # Setup matplotlib + plt.ion() + fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(12, 6)) + ax0.set_xlabel("cos - center") + ax0.set_ylabel("sin - center") + ax0.grid(True, alpha=0.25) + ax0.set_aspect("equal", adjustable="box") + ax1.set_title("Unit circle + angle") + ax1.set_xlabel("x") + ax1.set_ylabel("y") + ax1.grid(True, alpha=0.25) + ax1.set_aspect("equal", adjustable="box") + tt = np.linspace(0, 2 * np.pi, 360) + ax1.plot(np.cos(tt), np.sin(tt), "k-", linewidth=1) + ax0.set_xlim(-2200, 2200) + ax0.set_ylim(-2200, 2200) + ax1.set_xlim(-1.4, 1.4) + ax1.set_ylim(-1.4, 1.4) + + sc0 = ax0.scatter([], [], s=6, animated=True) + (ell_line,) = ax0.plot([], [], "r-", linewidth=2, animated=True) + sc1 = ax1.scatter([], [], s=6, animated=True) + (radius_line,) = ax1.plot([], [], "g-", linewidth=2, animated=True) + angle_text = ax1.text( + 0.02, 0.98, "", transform=ax1.transAxes, va="top", ha="left", fontsize=12, animated=True + 
) + + fig.canvas.draw() + bg0 = fig.canvas.copy_from_bbox(ax0.bbox) + bg1 = fig.canvas.copy_from_bbox(ax1.bbox) + + # State + joints_out = [] + joint_idx = 0 + phase = "ellipse" + advance_requested = False + zero_samples = [] + + def on_key(event): + nonlocal advance_requested + if event.key in ("n", "N", "enter", " "): + advance_requested = True + + fig.canvas.mpl_connect("key_press_event", on_key) + + def reset_state(): + return { + "xs": deque(maxlen=params.history), + "ys": deque(maxlen=params.history), + "xu": deque(maxlen=params.history), + "yu": deque(maxlen=params.history), + "win_s": deque(maxlen=params.median_window), + "win_c": deque(maxlen=params.median_window), + "ellipse_cache": None, + "T": None, + "center_fit": None, + "have_transform": False, + "latest_z": None, + "last_fit": 0.0, + } + + state = reset_state() + last_draw = 0.0 + name, pair = joint_list[joint_idx] + fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE") + ax0.set_title(f"{name} raw (filtered)") + logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}") + logger.info("Step 1: Move joint around to map ellipse, then press 'n'") + + try: + while plt.fignum_exists(fig.number): + name, pair = joint_list[joint_idx] + + # Handles calibration GUI state: ellipse → zero_pose → next joint -> ellipse -> ... 
+ if phase == "ellipse" and advance_requested and state["have_transform"]: + joints_out.append( + { + "name": name, + "center_fit": state["center_fit"].tolist(), + "T": state["T"].tolist(), + } + ) + logger.info(f" -> Ellipse saved for {name}") + phase, zero_samples, advance_requested = "zero_pose", [], False + fig.canvas.manager.set_window_title(f"[{joint_idx + 1}/{len(joint_list)}] {name} - ZERO POSE") + ax0.set_title(f"{name} - hold zero pose") + fig.canvas.draw() + bg0, bg1 = fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox) + logger.info(f"Step 2: Hold {name} in zero position, then press 'n'") + + elif phase == "ellipse" and advance_requested and not state["have_transform"]: + logger.info(" (Need valid fit first - keep moving the joint)") + advance_requested = False + + elif phase == "zero_pose" and advance_requested: + if len(zero_samples) >= params.sample_count: + zero_offset = float(np.mean(zero_samples[-params.sample_count :])) + joints_out[-1]["zero_offset"] = zero_offset + logger.info(f" -> {name} zero: {zero_offset:+.3f} rad ({np.degrees(zero_offset):+.1f}°)") + joint_idx += 1 + advance_requested = False + + if joint_idx >= len(joint_list): + # All joints done + calib = ExoskeletonCalibration( + version=2, + side=side, + adc_max=2**12 - 1, + joints=[ + ExoskeletonJointCalibration( + name=j["name"], + center_fit=j["center_fit"], + T=j["T"], + zero_offset=j.get("zero_offset", 0.0), + ) + for j in joints_out + ], + ) + save_path.parent.mkdir(parents=True, exist_ok=True) + with open(save_path, "w") as f: + json.dump(calib.to_dict(), f, indent=2) + logger.info(f"Saved calibration to {save_path}") + logger.info("Calibration complete!") + plt.close(fig) + return calib + + # Next joint + phase, state = "ellipse", reset_state() + name, pair = joint_list[joint_idx] + fig.canvas.manager.set_window_title( + f"[{joint_idx + 1}/{len(joint_list)}] {name} - ELLIPSE" + ) + ax0.set_title(f"{name} raw (filtered)") + fig.canvas.draw() + bg0, bg1 = 
fig.canvas.copy_from_bbox(ax0.bbox), fig.canvas.copy_from_bbox(ax1.bbox) + logger.info(f"[{joint_idx + 1}/{len(joint_list)}] Calibrating {name}") + logger.info("Step 1: Move joint around to map ellipse, then press 'n'") + else: + logger.info( + f" (Collecting samples: {len(zero_samples)}/{params.sample_count} - hold still)" + ) + advance_requested = False + + # Read sensor + raw16 = read_raw_from_serial(ser) + if raw16 is not None: + x_raw, y_raw, s_raw, c_raw = read_joint_point(raw16, pair) + + if phase == "ellipse": + if state["have_transform"]: + z = state["T"] @ (np.array([x_raw, y_raw]) - state["center_fit"]) + state["xu"].append(float(z[0])) + state["yu"].append(float(z[1])) + state["latest_z"] = (float(z[0]), float(z[1])) + state["win_s"].append(s_raw) + state["win_c"].append(c_raw) + if len(state["win_s"]) >= max(3, params.median_window): + state["ys"].append(running_median(state["win_s"]) - (2**12 - 1) / 2) + state["xs"].append(running_median(state["win_c"]) - (2**12 - 1) / 2) + else: + jdata = joints_out[-1] + z = np.array(jdata["T"]) @ (np.array([x_raw, y_raw]) - np.array(jdata["center_fit"])) + zero_samples.append(float(np.arctan2(z[1], z[0]))) + state["latest_z"] = (float(z[0]), float(z[1])) + + # Ellipse fitting + t = time.time() + if ( + phase == "ellipse" + and (t - state["last_fit"]) >= params.fit_every + and len(state["xs"]) >= params.min_fit_points + ): + xfit, yfit = select_fit_subset(state["xs"], state["ys"]) + if xfit is not None and len(xfit) >= params.min_fit_points: + fit = fit_ellipse_opencv(xfit, yfit) + if fit is not None: + state["center_fit"] = fit["center"] + state["T"] = np.diag([1.0 / fit["a"], 1.0 / fit["b"]]) @ fit["R"].T + state["ellipse_cache"] = (fit["ex"], fit["ey"]) + state["have_transform"] = True + state["last_fit"] = t + + # Drawing + if (t - last_draw) >= 1.0 / params.draw_hz: + fig.canvas.restore_region(bg0) + fig.canvas.restore_region(bg1) + + if phase == "ellipse": + sc0.set_offsets(np.c_[state["xs"], state["ys"]] if 
state["xs"] else np.empty((0, 2))) + ax0.draw_artist(sc0) + ell_line.set_data(*state["ellipse_cache"] if state["ellipse_cache"] else ([], [])) + ax0.draw_artist(ell_line) + sc1.set_offsets(np.c_[state["xu"], state["yu"]] if state["xu"] else np.empty((0, 2))) + ax1.draw_artist(sc1) + if state["latest_z"]: + zx, zy = state["latest_z"] + radius_line.set_data([0.0, zx], [0.0, zy]) + ang = float(np.arctan2(zy, zx)) + angle_text.set_text( + f"angle: {ang:+.3f} rad ({np.degrees(ang):+.1f}°)\nmove {name}, press 'n' to advance" + ) + else: + radius_line.set_data([], []) + angle_text.set_text("(waiting for fit)") + else: + sc0.set_offsets(np.empty((0, 2))) + ax0.draw_artist(sc0) + ell_line.set_data([], []) + ax0.draw_artist(ell_line) + if state["latest_z"]: + zx, zy = state["latest_z"] + sc1.set_offsets([[zx, zy]]) + radius_line.set_data([0.0, zx], [0.0, zy]) + ang = float(np.arctan2(zy, zx)) + angle_text.set_text( + f"Zero pose for {name}\nangle: {ang:+.3f} rad\nsamples: {len(zero_samples)}/{params.sample_count}\nhold still, press 'n'" + ) + else: + sc1.set_offsets(np.empty((0, 2))) + radius_line.set_data([], []) + angle_text.set_text("(waiting for data)") + ax1.draw_artist(sc1) + + ax1.draw_artist(radius_line) + ax1.draw_artist(angle_text) + fig.canvas.blit(ax0.bbox) + fig.canvas.blit(ax1.bbox) + fig.canvas.flush_events() + last_draw = t + + plt.pause(0.001) + + finally: + plt.close(fig) diff --git a/src/lerobot/teleoperators/unitree_g1/exo_ik.py b/src/lerobot/teleoperators/unitree_g1/exo_ik.py new file mode 100644 index 000000000..92519540f --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_ik.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +IK helper for exoskeleton-to-G1 teleoperation. We map Exoskeleton joint angles to end-effector pose in world frame, +visualizing the result in meshcat after calibration. +""" + +import logging +import os +from dataclasses import dataclass + +import numpy as np + +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointArmIndex +from lerobot.robots.unitree_g1.robot_kinematic_processor import G1_29_ArmIK + +from .exo_calib import JOINTS + +logger = logging.getLogger(__name__) + + +def _frame_id(model, name: str) -> int | None: + try: + fid = model.getFrameId(name) + return fid if 0 <= fid < model.nframes else None + except Exception: + return None + + +@dataclass +class ArmCfg: + side: str # "left" | "right" + urdf: str # exo_left.urdf / exo_right.urdf + root: str # "exo_left" / "exo_right" + g1_ee: str # "l_ee" / "r_ee" + offset: np.ndarray # world offset for viz + target + marker_prefix: str # "left" / "right" + + +class Markers: + """Creates meshcat visualization primitives, showing end-effector frames of exoskeleton and G1""" + + def __init__(self, viewer): + self.v = viewer + + def sphere(self, path: str, r: float, rgba: tuple[float, float, float, float]): + import meshcat.geometry as mg + + c = (int(rgba[0] * 255) << 16) | (int(rgba[1] * 255) << 8) | int(rgba[2] * 255) + self.v[path].set_object( + mg.Sphere(r), + mg.MeshPhongMaterial(color=c, opacity=rgba[3], transparent=rgba[3] < 1.0), + ) + + def axes(self, path: str, axis_len: float = 0.1, axis_w: int = 6): + import meshcat.geometry as mg + + pts = np.array( + [[0, 0, 0], [axis_len, 0, 0], 
[0, 0, 0], [0, axis_len, 0], [0, 0, 0], [0, 0, axis_len]], + dtype=np.float32, + ).T + cols = np.array( + [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 0], [0, 0, 1], [0, 0, 1]], + dtype=np.float32, + ).T + self.v[path].set_object( + mg.LineSegments( + mg.PointsGeometry(position=pts, color=cols), + mg.LineBasicMaterial(linewidth=axis_w, vertexColors=True), + ) + ) + + def tf(self, path: str, mat: np.ndarray): + self.v[path].set_transform(mat) + + +class ExoskeletonIKHelper: + """ + - Loads G1 robot and exoskeleton URDF models via Pinocchio + - Computes forward kinematics on exoskeleton to get end-effector poses + - Solves inverse kinematics on G1 to match those poses + - Provides meshcat visualization showing both robots and targets + + Args: + frozen_joints: List of G1 joint names to exclude from IK (kept at neutral). + """ + + def __init__(self, frozen_joints: list[str] | None = None): + try: + import pinocchio as pin + except ImportError as e: + raise ImportError("ik mode needs pinocchio: pip install pin") from e + + self.pin = pin + self.frozen_joints = frozen_joints or [] + + self.g1_ik = G1_29_ArmIK() + self.robot_g1 = self.g1_ik.reduced_robot + self.robot_g1.data = self.robot_g1.model.createData() + self.q_g1 = pin.neutral(self.robot_g1.model) + + assets_dir = os.path.join(self.g1_ik.repo_path, "assets") + + self.frozen_idx = self._frozen_joint_indices() + + self.arms = [ + ArmCfg( + side="left", + urdf=os.path.join(assets_dir, "exo_left.urdf"), + root="exo_left", + g1_ee="L_ee", + offset=np.array([0.6, 0.3, 0.0]), + marker_prefix="left", + ), + ArmCfg( + side="right", + urdf=os.path.join(assets_dir, "exo_right.urdf"), + root="exo_right", + g1_ee="R_ee", + offset=np.array([0.6, -0.3, 0.0]), + marker_prefix="right", + ), + ] + + self.exo = {} # side -> pin.RobotWrapper + self.q_exo = {} # side -> q + self.ee_id_exo = {} # side -> frame id + self.qmap = {} # side -> {joint_name: q_idx} + self.ee_id_g1 = {} # side -> frame id + + self._load_exo_models(assets_dir) + 
for a in self.arms: + self.ee_id_g1[a.side] = _frame_id(self.robot_g1.model, a.g1_ee) + + self.viewer = None + self.markers: Markers | None = None + self.viz_g1 = None + self.viz_exo = {} # side -> viz + + def _frozen_joint_indices(self) -> dict[str, int]: + out = {} + m = self.robot_g1.model + for name in self.frozen_joints: + if name in m.names: + jid = m.getJointId(name) + out[name] = m.idx_qs[jid] + logger.info(f"freezing joint: {name} (q_idx={out[name]})") + return out + + def _find_exo_ee(self, model, ee_name: str = "ee") -> int: + ee = _frame_id(model, ee_name) + if ee is not None: + return ee + for fid in reversed(range(model.nframes)): + if model.frames[fid].type == self.pin.FrameType.BODY: + return fid + return 0 + + def _build_joint_map(self, robot) -> dict[str, int]: + m = robot.model + return {n: m.idx_qs[m.getJointId(n)] for n in JOINTS if n in m.names} + + def _load_exo_models(self, assets_dir: str): + pin = self.pin + for a in self.arms: + if not os.path.exists(a.urdf): + logger.warning(f"{a.side} exo urdf not found: {a.urdf}") + continue + r = pin.RobotWrapper.BuildFromURDF(a.urdf, assets_dir) + self.exo[a.side] = r + self.q_exo[a.side] = pin.neutral(r.model) + self.ee_id_exo[a.side] = self._find_exo_ee(r.model) + self.qmap[a.side] = self._build_joint_map(r) + logger.info(f"loaded {a.side} exo urdf: {a.urdf}") + + def init_visualization(self): + """ + Creates a browser-based visualization of exoskeleton and G1 robot, + highlighting end-effector frames and target positions. 
+ """ + try: + from pinocchio.visualize import MeshcatVisualizer + except ImportError as e: + logger.warning(f"meshcat viz unavailable: {e}") + return + + # g1 + self.viz_g1 = MeshcatVisualizer( + self.robot_g1.model, self.robot_g1.collision_model, self.robot_g1.visual_model + ) + self.viz_g1.initViewer(open=True) + self.viz_g1.loadViewerModel("g1") + self.viz_g1.display(self.q_g1) + + self.viewer = self.viz_g1.viewer + self.markers = Markers(self.viewer) + + # exos + for a in self.arms: + if a.side not in self.exo: + continue + r = self.exo[a.side] + v = MeshcatVisualizer(r.model, r.collision_model, r.visual_model) + v.initViewer(open=False) + v.viewer = self.viewer + v.loadViewerModel(a.root) + offset_tf = np.eye(4) + offset_tf[:3, 3] = a.offset + self.viewer[a.root].set_transform(offset_tf) + v.display(self.q_exo[a.side]) + self.viz_exo[a.side] = v + + # markers + for a in self.arms: + p = a.marker_prefix + self.markers.sphere(f"markers/{p}_exo_ee", 0.012, (0.2, 1.0, 0.2, 0.9)) + self.markers.sphere(f"markers/{p}_g1_ee", 0.015, (1.0, 0.2, 0.2, 0.9)) + self.markers.sphere(f"markers/{p}_ik_target", 0.015, (0.1, 0.3, 1.0, 0.9)) + self.markers.axes(f"markers/{p}_exo_axes", 0.06) + self.markers.axes(f"markers/{p}_g1_axes", 0.08) + + logger.info(f"meshcat viz initialized: {self.viewer.url()}") + print(f"\nmeshcat url: {self.viewer.url()}\n") + + def _fk_target_world(self, side: str, angles: dict[str, float]) -> np.ndarray | None: + """returns wrist frame target to be used for G1 IK in 4x4 homogeneous transform. 
Takes offset into account.""" + if side not in self.exo or not angles: + return None + + pin = self.pin + q = self.q_exo[side] + qmap = self.qmap[side] + + for name, ang in angles.items(): + idx = qmap.get(name) + if idx is not None: + q[idx] = float(ang) + + r = self.exo[side] + pin.forwardKinematics(r.model, r.data, q) + pin.updateFramePlacements(r.model, r.data) + + ee = r.data.oMf[self.ee_id_exo[side]] + target = np.eye(4) + target[:3, :3] = ee.rotation + # offset gets applied in world space + cfg = next(a for a in self.arms if a.side == side) + target[:3, 3] = cfg.offset + ee.translation + return target + + def update_visualization(self): + if self.viewer is None or self.markers is None: + return + + pin = self.pin + + # g1 + if self.viz_g1 is not None: + self.viz_g1.display(self.q_g1) + pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1) + pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data) + + for a in self.arms: + fid = self.ee_id_g1.get(a.side) + if fid is None: + continue + ee_tf = self.robot_g1.data.oMf[fid].homogeneous + p = a.marker_prefix + self.markers.tf(f"markers/{p}_g1_ee", ee_tf) + self.markers.tf(f"markers/{p}_g1_axes", ee_tf) + + # exos + for a in self.arms: + side = a.side + v = self.viz_exo.get(side) + if v is None: + continue + + v.display(self.q_exo[side]) + r = self.exo[side] + pin.forwardKinematics(r.model, r.data, self.q_exo[side]) + pin.updateFramePlacements(r.model, r.data) + + ee = r.data.oMf[self.ee_id_exo[side]] + world_tf = (pin.SE3(np.eye(3), a.offset) * ee).homogeneous + p = a.marker_prefix + self.markers.tf(f"markers/{p}_exo_ee", world_tf) + self.markers.tf(f"markers/{p}_exo_axes", world_tf) + + target_tf = np.eye(4) + target_tf[:3, :3] = ee.rotation + target_tf[:3, 3] = a.offset + ee.translation + self.markers.tf(f"markers/{p}_ik_target", target_tf) + + def compute_g1_joints_from_exo( + self, + left_angles: dict[str, float], + right_angles: dict[str, float], + ) -> dict[str, float]: + """ + 
Performs FK on exoskeleton to get end-effector poses in world frame, + after which it solves IK on G1 to return joint angles matching those poses in G1 motor order. + """ + pin = self.pin + + targets = { + "left": self._fk_target_world("left", left_angles), + "right": self._fk_target_world("right", right_angles), + } + + # fallback to current g1 ee pose if missing target + pin.forwardKinematics(self.robot_g1.model, self.robot_g1.data, self.q_g1) + pin.updateFramePlacements(self.robot_g1.model, self.robot_g1.data) + + for a in self.arms: + if targets[a.side] is not None: + continue + fid = self.ee_id_g1.get(a.side) + if fid is not None: + targets[a.side] = self.robot_g1.data.oMf[fid].homogeneous + + if targets["left"] is None or targets["right"] is None: + logger.warning("missing ik targets, returning current pose") + return {} + + frozen_vals = {n: self.q_g1[i] for n, i in self.frozen_idx.items()} + + self.q_g1, _ = self.g1_ik.solve_ik( + targets["left"], targets["right"], current_lr_arm_motor_q=self.q_g1 + ) + + for n, i in self.frozen_idx.items(): + self.q_g1[i] = frozen_vals[n] + + return { + f"{j.name}.q": float(self.q_g1[i]) + for i, j in enumerate(G1_29_JointArmIndex) + if i < len(self.q_g1) + } diff --git a/src/lerobot/teleoperators/unitree_g1/exo_serial.py b/src/lerobot/teleoperators/unitree_g1/exo_serial.py new file mode 100644 index 000000000..1211c57cc --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/exo_serial.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +from dataclasses import dataclass +from pathlib import Path + +import serial + +from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration + +logger = logging.getLogger(__name__) + + +def parse_raw16(line: bytes) -> list[int] | None: + try: + parts = line.decode("utf-8", errors="ignore").split() + if len(parts) < 16: + return None + return [int(x) for x in parts[:16]] + except Exception: + return None + + +def read_raw_from_serial(ser) -> list[int] | None: + """Read latest sample from serial; if buffer is backed up, keep only the newest.""" + last = None + while ser.in_waiting > 0: + b = ser.readline() + if not b: + break + raw16 = parse_raw16(b) + if raw16 is not None: + last = raw16 + if last is None: + b = ser.readline() + if b: + last = parse_raw16(b) + return last + + +@dataclass +class ExoskeletonArm: + port: str + calibration_fpath: Path + side: str + baud_rate: int = 115200 + + _ser: serial.Serial | None = None + calibration: ExoskeletonCalibration | None = None + + def __post_init__(self): + if self.calibration_fpath.is_file(): + self._load_calibration() + + @property + def is_connected(self) -> bool: + return self._ser is not None and getattr(self._ser, "is_open", False) + + @property + def is_calibrated(self) -> bool: + return self.calibration is not None + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + return + try: + self._ser = serial.Serial(self.port, self.baud_rate, timeout=0.02) + self._ser.reset_input_buffer() + logger.info(f"connected: {self.port}") 
+ except serial.SerialException as e: + raise ConnectionError(f"failed to connect to {self.port}: {e}") from e + + if calibrate and not self.is_calibrated: + self.calibrate() + + def disconnect(self) -> None: + if self._ser: + try: + self._ser.close() + finally: + self._ser = None + + def _load_calibration(self) -> None: + try: + data = json.loads(self.calibration_fpath.read_text()) + self.calibration = ExoskeletonCalibration.from_dict(data) + logger.info(f"loaded calibration: {self.calibration_fpath}") + except Exception as e: + logger.warning(f"failed to load calibration: {e}") + + def read_raw(self) -> list[int] | None: + if not self._ser: + return None + return read_raw_from_serial(self._ser) + + def get_angles(self) -> dict[str, float]: + if not self.calibration: + raise RuntimeError("exoskeleton not calibrated") + raw = self.read_raw() + return {} if raw is None else exo_raw_to_angles(raw, self.calibration) + + def calibrate(self) -> None: + ser = self._ser + self.calibration = run_exo_calibration(ser, self.side, self.calibration_fpath) diff --git a/src/lerobot/teleoperators/unitree_g1/unitree_g1.py b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py new file mode 100644 index 000000000..3779d83ec --- /dev/null +++ b/src/lerobot/teleoperators/unitree_g1/unitree_g1.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +import time +from functools import cached_property + +from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex +from lerobot.utils.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS + +from ..teleoperator import Teleoperator +from .config_unitree_g1 import UnitreeG1TeleoperatorConfig +from .exo_ik import ExoskeletonIKHelper +from .exo_serial import ExoskeletonArm + +logger = logging.getLogger(__name__) + + +class UnitreeG1Teleoperator(Teleoperator): + """ + Bimanual exoskeleton arms teleoperator for Unitree G1 arms. + + Uses inverse kinematics: exoskeleton FK computes end-effector pose, + G1 IK solves for joint angles. + """ + + config_class = UnitreeG1TeleoperatorConfig + name = "unitree_g1" + + def __init__(self, config: UnitreeG1TeleoperatorConfig): + super().__init__(config) + self.config = config + + # Setup calibration directory + self.calibration_dir = ( + config.calibration_dir + if config.calibration_dir + else HF_LEROBOT_CALIBRATION / TELEOPERATORS / self.name + ) + self.calibration_dir.mkdir(parents=True, exist_ok=True) + + left_id = f"{config.id}_left" if config.id else "left" + right_id = f"{config.id}_right" if config.id else "right" + + # Create exoskeleton arm instances + self.left_arm = ExoskeletonArm( + port=config.left_arm_config.port, + baud_rate=config.left_arm_config.baud_rate, + calibration_fpath=self.calibration_dir / f"{left_id}.json", + side="left", + ) + self.right_arm = ExoskeletonArm( + port=config.right_arm_config.port, + baud_rate=config.right_arm_config.baud_rate, + calibration_fpath=self.calibration_dir / f"{right_id}.json", + side="right", + ) + + self.ik_helper: ExoskeletonIKHelper | None = None + + @cached_property + def action_features(self) -> dict[str, type]: + return {f"{name}.q": float for name in self._g1_joint_names} + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and 
self.right_arm.is_connected + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()] + self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints) + logger.info("IK helper initialized") + + def calibrate(self) -> None: + if not self.left_arm.is_calibrated: + logger.info("Starting calibration for left arm...") + self.left_arm.calibrate() + else: + logger.info("Left arm already calibrated. Skipping.") + + if not self.right_arm.is_calibrated: + logger.info("Starting calibration for right arm...") + self.right_arm.calibrate() + else: + logger.info("Right arm already calibrated. Skipping.") + + logger.info("Starting visualization to verify calibration...") + self.run_visualization_loop() + + def configure(self) -> None: + pass + + def get_action(self) -> dict[str, float]: + left_angles = self.left_arm.get_angles() + right_angles = self.right_arm.get_angles() + return self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError("Exoskeleton arms do not support feedback") + + def disconnect(self) -> None: + self.left_arm.disconnect() + self.right_arm.disconnect() + + def run_visualization_loop(self): + """Run interactive Meshcat visualization loop to verify tracking.""" + if self.ik_helper is None: + frozen_joints = [j.strip() for j in self.config.frozen_joints.split(",") if j.strip()] + self.ik_helper = ExoskeletonIKHelper(frozen_joints=frozen_joints) + + self.ik_helper.init_visualization() + + print("\n" + "=" * 60) + print("Visualization running! 
Move the exoskeletons to test tracking.") + print("Press Ctrl+C to exit.") + print("=" * 60 + "\n") + + try: + while True: + left_angles = self.left_arm.get_angles() + right_angles = self.right_arm.get_angles() + + self.ik_helper.compute_g1_joints_from_exo(left_angles, right_angles) + self.ik_helper.update_visualization() + + time.sleep(0.01) + + except KeyboardInterrupt: + print("\n\nVisualization stopped.") + + @cached_property + def _g1_joint_names(self) -> list[str]: + return [joint.name for joint in G1_29_JointIndex] diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 8f6bbc787..3b42d294e 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -75,6 +75,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .homunculus import HomunculusArm return HomunculusArm(config) + elif config.type == "unitree_g1": + from .unitree_g1 import UnitreeG1Teleoperator + + return UnitreeG1Teleoperator(config) elif config.type == "bi_so_leader": from .bi_so_leader import BiSOLeader From 4483184875e0c283066ba304e2404638f8aa803e Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 28 Jan 2026 17:25:57 +0100 Subject: [PATCH 069/111] feat(robots): add bi manual openarm follower and leader (#2835) * fix(motors): cleanup imports + fix signatures * feat(motors): add damiao canbus + multiple fixes * fix(motors): address comments -> last_state + different gains + sleep * refactor(motors): reduce duplicated code + adressed some comments in the PR * chore(motors): better timeouts * tests(motors): damiao test and imports * chore(deps): fix space * feat(robot): add openarm leader Co-authored-by: Pepijn * feat(robot): add openarm follower Co-authored-by: Pepijn * refactor(robot): remove mechanical compensations and double arm assumption + rename * chore(robots): remove left arm references * refactor(teleop): multiple improvements to leader * refactor(teleop): multiple improvements 
to leader * feat(robots): add open arm to util CLI * chore(robot): add alias openarm * Apply suggestions from code review Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> Signed-off-by: Steven Palma * chore(motors): remove normalization tables damiao * fix(motors): imports and signatures * feat(motors): add motor_type_str + recv_id to motor class and _get_motor_recv_id raises if no motor_obj.recv_id * chore(motors): remove normalize from base motor class and damiao * tests(motors): remove bad tests (to be replaced) * chore(motors): updated import check * fix(robots): open arm mirrored config for joint limits * chore(motors): update position_kd gain values * chore(robots): set to 0 if openarm is calibrated at connect time * chore(robots): remove macos in open arm as can doesn't support it * chore(robots): update for motor_type_str in Motor class * chore(robots): no default value for can port in open arms * feat(robots): add bi manual openarm follower and leader * use constant for kp and kd range and check responses in mit_control_batch() * Add docs on setting up canbus and use damiao motor bus, also add lerobot_setup_can.py and log if there is no response from a write command * precommit format * suppress bandit as these are intentional cli commands * fix setup-can * add test * skip test in ci * nit precommit * update doc example * don't import can for tests * remove comment * Add openarms docs * format * update purchase link * can to none if not available * add canfd option in bus * make handshake logic similar to lerobot-can * type hint * type check * add temp teleop test * remove script * mock class * mock class * ignore linter * pre-commit * Add command for bimanual openarm * fix import * fix import leader * fix import draccus --------- Signed-off-by: Steven Palma Co-authored-by: Pepijn Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com> --- docs/source/openarm.mdx | 18 ++ .../robots/bi_openarm_follower/__init__.py | 20 ++
.../bi_openarm_follower.py | 175 ++++++++++++++++++ .../config_bi_openarm_follower.py | 30 +++ .../robots/openarm_follower/__init__.py | 4 +- .../config_openarm_follower.py | 11 +- src/lerobot/robots/utils.py | 4 + src/lerobot/scripts/lerobot_calibrate.py | 2 + .../scripts/lerobot_find_joint_limits.py | 2 + src/lerobot/scripts/lerobot_record.py | 2 + src/lerobot/scripts/lerobot_replay.py | 1 + src/lerobot/scripts/lerobot_teleoperate.py | 2 + .../bi_openarm_leader/__init__.py | 20 ++ .../bi_openarm_leader/bi_openarm_leader.py | 131 +++++++++++++ .../config_bi_openarm_leader.py | 30 +++ .../teleoperators/openarm_leader/__init__.py | 4 +- .../openarm_leader/config_openarm_leader.py | 11 +- src/lerobot/teleoperators/utils.py | 4 + 18 files changed, 461 insertions(+), 10 deletions(-) create mode 100644 src/lerobot/robots/bi_openarm_follower/__init__.py create mode 100644 src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py create mode 100644 src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/__init__.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py create mode 100644 src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py diff --git a/docs/source/openarm.mdx b/docs/source/openarm.mdx index 661808749..cd4ace912 100644 --- a/docs/source/openarm.mdx +++ b/docs/source/openarm.mdx @@ -174,6 +174,24 @@ lerobot-teleoperate \ --teleop.id=my_leader ``` +### Bimanual Teleoperation + +To teleoperate a bimanual OpenArm setup with two leader and two follower arms: + +```bash +lerobot-teleoperate \ + --robot.type=bi_openarm_follower \ + --robot.left_arm_config.port=can0 \ + --robot.left_arm_config.side=left \ + --robot.right_arm_config.port=can1 \ + --robot.right_arm_config.side=right \ + --robot.id=my_bimanual_follower \ + --teleop.type=bi_openarm_leader \ + --teleop.left_arm_config.port=can2 \ + --teleop.right_arm_config.port=can3 \ + 
--teleop.id=my_bimanual_leader +``` + ### Recording Data To record a dataset during teleoperation: diff --git a/src/lerobot/robots/bi_openarm_follower/__init__.py b/src/lerobot/robots/bi_openarm_follower/__init__.py new file mode 100644 index 000000000..b1dcce431 --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .bi_openarm_follower import BiOpenArmFollower +from .config_bi_openarm_follower import BiOpenArmFollowerConfig + +__all__ = ["BiOpenArmFollower", "BiOpenArmFollowerConfig"] diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py new file mode 100644 index 000000000..466eb07e5 --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -0,0 +1,175 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from functools import cached_property + +from lerobot.processor import RobotAction, RobotObservation +from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig + +from ..robot import Robot +from .config_bi_openarm_follower import BiOpenArmFollowerConfig + +logger = logging.getLogger(__name__) + + +class BiOpenArmFollower(Robot): + """ + Bimanual OpenArm Follower Arms + """ + + config_class = BiOpenArmFollowerConfig + name = "bi_openarm_follower" + + def __init__(self, config: BiOpenArmFollowerConfig): + super().__init__(config) + self.config = config + + left_arm_config = OpenArmFollowerConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_config.port, + disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect, + max_relative_target=config.left_arm_config.max_relative_target, + cameras=config.left_arm_config.cameras, + side=config.left_arm_config.side, + can_interface=config.left_arm_config.can_interface, + use_can_fd=config.left_arm_config.use_can_fd, + can_bitrate=config.left_arm_config.can_bitrate, + can_data_bitrate=config.left_arm_config.can_data_bitrate, + motor_config=config.left_arm_config.motor_config, + position_kd=config.left_arm_config.position_kd, + position_kp=config.left_arm_config.position_kp, + joint_limits=config.left_arm_config.joint_limits, + ) + + right_arm_config = OpenArmFollowerConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + 
port=config.right_arm_config.port, + disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect, + max_relative_target=config.right_arm_config.max_relative_target, + cameras=config.right_arm_config.cameras, + side=config.right_arm_config.side, + can_interface=config.right_arm_config.can_interface, + use_can_fd=config.right_arm_config.use_can_fd, + can_bitrate=config.right_arm_config.can_bitrate, + can_data_bitrate=config.right_arm_config.can_data_bitrate, + motor_config=config.right_arm_config.motor_config, + position_kd=config.right_arm_config.position_kd, + position_kp=config.right_arm_config.position_kp, + joint_limits=config.right_arm_config.joint_limits, + ) + + self.left_arm = OpenArmFollower(left_arm_config) + self.right_arm = OpenArmFollower(right_arm_config) + + # Only for compatibility with other parts of the codebase that expect a `robot.cameras` attribute + self.cameras = {**self.left_arm.cameras, **self.right_arm.cameras} + + @property + def _motors_ft(self) -> dict[str, type]: + left_arm_motors_ft = self.left_arm._motors_ft + right_arm_motors_ft = self.right_arm._motors_ft + + return { + **{f"left_{k}": v for k, v in left_arm_motors_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_motors_ft.items()}, + } + + @property + def _cameras_ft(self) -> dict[str, tuple]: + left_arm_cameras_ft = self.left_arm._cameras_ft + right_arm_cameras_ft = self.right_arm._cameras_ft + + return { + **{f"left_{k}": v for k, v in left_arm_cameras_ft.items()}, + **{f"right_{k}": v for k, v in right_arm_cameras_ft.items()}, + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + def connect(self, calibrate: bool = True) -> None: + 
self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." + ) + + def get_observation(self) -> RobotObservation: + obs_dict = {} + + # Add "left_" prefix + left_obs = self.left_arm.get_observation() + obs_dict.update({f"left_{key}": value for key, value in left_obs.items()}) + + # Add "right_" prefix + right_obs = self.right_arm.get_observation() + obs_dict.update({f"right_{key}": value for key, value in right_obs.items()}) + + return obs_dict + + def send_action( + self, + action: RobotAction, + custom_kp: dict[str, float] | None = None, + custom_kd: dict[str, float] | None = None, + ) -> RobotAction: + # Remove "left_" prefix + left_action = { + key.removeprefix("left_"): value for key, value in action.items() if key.startswith("left_") + } + # Remove "right_" prefix + right_action = { + key.removeprefix("right_"): value for key, value in action.items() if key.startswith("right_") + } + + sent_action_left = self.left_arm.send_action(left_action, custom_kp, custom_kd) + sent_action_right = self.right_arm.send_action(right_action, custom_kp, custom_kd) + + # Add prefixes back + prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()} + prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()} + + return {**prefixed_sent_action_left, **prefixed_sent_action_right} + + def disconnect(self): + self.left_arm.disconnect() + self.right_arm.disconnect() diff --git a/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py 
b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py new file mode 100644 index 000000000..9d11f7b4e --- /dev/null +++ b/src/lerobot/robots/bi_openarm_follower/config_bi_openarm_follower.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from lerobot.robots.openarm_follower import OpenArmFollowerConfigBase + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("bi_openarm_follower") +@dataclass +class BiOpenArmFollowerConfig(RobotConfig): + """Configuration class for Bi OpenArm Follower robots.""" + + left_arm_config: OpenArmFollowerConfigBase + right_arm_config: OpenArmFollowerConfigBase diff --git a/src/lerobot/robots/openarm_follower/__init__.py b/src/lerobot/robots/openarm_follower/__init__.py index 1eb0d9fc7..217432fd5 100644 --- a/src/lerobot/robots/openarm_follower/__init__.py +++ b/src/lerobot/robots/openarm_follower/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .config_openarm_follower import OpenArmFollowerConfig +from .config_openarm_follower import OpenArmFollowerConfig, OpenArmFollowerConfigBase from .openarm_follower import OpenArmFollower -__all__ = ["OpenArmFollower", "OpenArmFollowerConfig"] +__all__ = ["OpenArmFollower", "OpenArmFollowerConfig", "OpenArmFollowerConfigBase"] diff --git a/src/lerobot/robots/openarm_follower/config_openarm_follower.py b/src/lerobot/robots/openarm_follower/config_openarm_follower.py index af95b6395..88d81fd50 100644 --- a/src/lerobot/robots/openarm_follower/config_openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/config_openarm_follower.py @@ -43,10 +43,9 @@ RIGHT_DEFAULT_JOINTS_LIMITS: dict[str, tuple[float, float]] = { } -@RobotConfig.register_subclass("openarm_follower") @dataclass -class OpenArmFollowerConfig(RobotConfig): - """Configuration for the OpenArms follower robot with Damiao motors.""" +class OpenArmFollowerConfigBase: + """Base configuration for the OpenArms follower robot with Damiao motors.""" # CAN interfaces - one per arm # arm CAN interface (e.g., "can1") @@ -115,3 +114,9 @@ class OpenArmFollowerConfig(RobotConfig): "gripper": (-5.0, 0.0), } ) + + +@RobotConfig.register_subclass("openarm_follower") +@dataclass +class OpenArmFollowerConfig(RobotConfig, OpenArmFollowerConfigBase): + pass diff --git a/src/lerobot/robots/utils.py b/src/lerobot/robots/utils.py index e0c76cab3..92da597f1 100644 --- a/src/lerobot/robots/utils.py +++ b/src/lerobot/robots/utils.py @@ -64,6 +64,10 @@ def make_robot_from_config(config: RobotConfig) -> Robot: from .openarm_follower import OpenArmFollower return OpenArmFollower(config) + elif config.type == "bi_openarm_follower": + from .bi_openarm_follower import BiOpenArmFollower + + return BiOpenArmFollower(config) elif config.type == "mock_robot": from tests.mocks.mock_robot import MockRobot diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index 2fa1b2a03..eb3df6872 100644 
--- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -36,6 +36,7 @@ from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraCon from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, hope_jr, koch_follower, @@ -48,6 +49,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, homunculus, koch_leader, diff --git a/src/lerobot/scripts/lerobot_find_joint_limits.py b/src/lerobot/scripts/lerobot_find_joint_limits.py index d928dc5cd..082d11803 100644 --- a/src/lerobot/scripts/lerobot_find_joint_limits.py +++ b/src/lerobot/scripts/lerobot_find_joint_limits.py @@ -44,6 +44,7 @@ import numpy as np from lerobot.model.kinematics import RobotKinematics from lerobot.robots import ( # noqa: F401 RobotConfig, + bi_openarm_follower, bi_so_follower, koch_follower, make_robot_from_config, @@ -53,6 +54,7 @@ from lerobot.robots import ( # noqa: F401 ) from lerobot.teleoperators import ( # noqa: F401 TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, gamepad, koch_leader, diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index d621189e8..0b39e6fff 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -98,6 +98,7 @@ from lerobot.processor.rename_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, @@ -112,6 +113,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, homunculus, koch_leader, diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index c3bc3d766..5717dffb6 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ 
b/src/lerobot/scripts/lerobot_replay.py @@ -53,6 +53,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, diff --git a/src/lerobot/scripts/lerobot_teleoperate.py b/src/lerobot/scripts/lerobot_teleoperate.py index 958bd00ef..b6aa4a750 100644 --- a/src/lerobot/scripts/lerobot_teleoperate.py +++ b/src/lerobot/scripts/lerobot_teleoperate.py @@ -70,6 +70,7 @@ from lerobot.processor import ( from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, + bi_openarm_follower, bi_so_follower, earthrover_mini_plus, hope_jr, @@ -84,6 +85,7 @@ from lerobot.robots import ( # noqa: F401 from lerobot.teleoperators import ( # noqa: F401 Teleoperator, TeleoperatorConfig, + bi_openarm_leader, bi_so_leader, gamepad, homunculus, diff --git a/src/lerobot/teleoperators/bi_openarm_leader/__init__.py b/src/lerobot/teleoperators/bi_openarm_leader/__init__.py new file mode 100644 index 000000000..fe728b826 --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/__init__.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .bi_openarm_leader import BiOpenArmLeader +from .config_bi_openarm_leader import BiOpenArmLeaderConfig + +__all__ = ["BiOpenArmLeader", "BiOpenArmLeaderConfig"] diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py new file mode 100644 index 000000000..c4383293f --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from functools import cached_property + +from lerobot.processor import RobotAction +from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig + +from ..openarm_leader import OpenArmLeader +from ..teleoperator import Teleoperator +from .config_bi_openarm_leader import BiOpenArmLeaderConfig + +logger = logging.getLogger(__name__) + + +class BiOpenArmLeader(Teleoperator): + """ + Bimanual OpenArm Leader Arms + """ + + config_class = BiOpenArmLeaderConfig + name = "bi_openarm_leader" + + def __init__(self, config: BiOpenArmLeaderConfig): + super().__init__(config) + self.config = config + + left_arm_config = OpenArmLeaderConfig( + id=f"{config.id}_left" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.left_arm_config.port, + can_interface=config.left_arm_config.can_interface, + use_can_fd=config.left_arm_config.use_can_fd, + can_bitrate=config.left_arm_config.can_bitrate, + can_data_bitrate=config.left_arm_config.can_data_bitrate, + motor_config=config.left_arm_config.motor_config, + manual_control=config.left_arm_config.manual_control, + position_kd=config.left_arm_config.position_kd, + position_kp=config.left_arm_config.position_kp, + ) + + right_arm_config = OpenArmLeaderConfig( + id=f"{config.id}_right" if config.id else None, + calibration_dir=config.calibration_dir, + port=config.right_arm_config.port, + can_interface=config.right_arm_config.can_interface, + use_can_fd=config.right_arm_config.use_can_fd, + can_bitrate=config.right_arm_config.can_bitrate, + can_data_bitrate=config.right_arm_config.can_data_bitrate, + motor_config=config.right_arm_config.motor_config, + manual_control=config.right_arm_config.manual_control, + position_kd=config.right_arm_config.position_kd, + position_kp=config.right_arm_config.position_kp, + ) + + self.left_arm = OpenArmLeader(left_arm_config) + self.right_arm = OpenArmLeader(right_arm_config) + + @cached_property + def action_features(self) -> dict[str, type]: + 
left_arm_features = self.left_arm.action_features + right_arm_features = self.right_arm.action_features + + return { + **{f"left_{k}": v for k, v in left_arm_features.items()}, + **{f"right_{k}": v for k, v in right_arm_features.items()}, + } + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.left_arm.is_connected and self.right_arm.is_connected + + def connect(self, calibrate: bool = True) -> None: + self.left_arm.connect(calibrate) + self.right_arm.connect(calibrate) + + @property + def is_calibrated(self) -> bool: + return self.left_arm.is_calibrated and self.right_arm.is_calibrated + + def calibrate(self) -> None: + self.left_arm.calibrate() + self.right_arm.calibrate() + + def configure(self) -> None: + self.left_arm.configure() + self.right_arm.configure() + + def setup_motors(self) -> None: + raise NotImplementedError( + "Motor ID configuration is typically done via manufacturer tools for CAN motors." 
+ ) + + def get_action(self) -> RobotAction: + action_dict = {} + + # Add "left_" prefix + left_action = self.left_arm.get_action() + action_dict.update({f"left_{key}": value for key, value in left_action.items()}) + + # Add "right_" prefix + right_action = self.right_arm.get_action() + action_dict.update({f"right_{key}": value for key, value in right_action.items()}) + + return action_dict + + def send_feedback(self, feedback: dict[str, float]) -> None: + # TODO: Implement force feedback + raise NotImplementedError + + def disconnect(self) -> None: + self.left_arm.disconnect() + self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py new file mode 100644 index 000000000..39fc90add --- /dev/null +++ b/src/lerobot/teleoperators/bi_openarm_leader/config_bi_openarm_leader.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass + +from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfigBase + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("bi_openarm_leader") +@dataclass +class BiOpenArmLeaderConfig(TeleoperatorConfig): + """Configuration class for Bi OpenArm Leader teleoperators.""" + + left_arm_config: OpenArmLeaderConfigBase + right_arm_config: OpenArmLeaderConfigBase diff --git a/src/lerobot/teleoperators/openarm_leader/__init__.py b/src/lerobot/teleoperators/openarm_leader/__init__.py index 1493317fe..172cf8228 100644 --- a/src/lerobot/teleoperators/openarm_leader/__init__.py +++ b/src/lerobot/teleoperators/openarm_leader/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .config_openarm_leader import OpenArmLeaderConfig +from .config_openarm_leader import OpenArmLeaderConfig, OpenArmLeaderConfigBase from .openarm_leader import OpenArmLeader -__all__ = ["OpenArmLeader", "OpenArmLeaderConfig"] +__all__ = ["OpenArmLeader", "OpenArmLeaderConfig", "OpenArmLeaderConfigBase"] diff --git a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py index c53169b0a..4b12fe730 100644 --- a/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/config_openarm_leader.py @@ -19,10 +19,9 @@ from dataclasses import dataclass, field from ..config import TeleoperatorConfig -@TeleoperatorConfig.register_subclass("openarm_leader") @dataclass -class OpenArmLeaderConfig(TeleoperatorConfig): - """Configuration for the OpenArms leader/teleoperator with Damiao motors.""" +class OpenArmLeaderConfigBase: + """Base configuration for the OpenArms leader/teleoperator with Damiao motors.""" # CAN interfaces - one per arm # Arm CAN interface (e.g., "can3") @@ -68,3 +67,9 @@ class 
OpenArmLeaderConfig(TeleoperatorConfig): default_factory=lambda: [240.0, 240.0, 240.0, 240.0, 24.0, 31.0, 25.0, 16.0] ) position_kd: list[float] = field(default_factory=lambda: [3.0, 3.0, 3.0, 3.0, 0.2, 0.2, 0.2, 0.2]) + + +@TeleoperatorConfig.register_subclass("openarm_leader") +@dataclass +class OpenArmLeaderConfig(TeleoperatorConfig, OpenArmLeaderConfigBase): + pass diff --git a/src/lerobot/teleoperators/utils.py b/src/lerobot/teleoperators/utils.py index 3b42d294e..16454d5ad 100644 --- a/src/lerobot/teleoperators/utils.py +++ b/src/lerobot/teleoperators/utils.py @@ -91,6 +91,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator": from .openarm_leader import OpenArmLeader return OpenArmLeader(config) + elif config.type == "bi_openarm_leader": + from .bi_openarm_leader import BiOpenArmLeader + + return BiOpenArmLeader(config) else: try: return cast("Teleoperator", make_device_from_device_class(config)) From 3409ef0dc2fc948468ec573a1e47f8ab7747cee2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Thu, 29 Jan 2026 04:07:47 -0600 Subject: [PATCH 070/111] refactor(cameras): cameras API extension (#2808) * feat(cameras): add new read_latest() method * fix(cameras): fix threading bug + clear state * refactor(cameras): multiple improvements * feat(camera): add context manager to camera base class * chore(camera): slight modifications to opencv * test(cameras): update opencv tests according to the changes * refactor(cameras): reflect design changes to realsense + deal with depth * test(cameras): fix realsense tests according to new changes * refactor(cameras): update reachymini and zmq accordingly * chore: wrap resource sensitive examples into a try/finally * test(cameras): add test for new read_latest * test(cameras): fix problem with image artifact in opencv tests * test(cameras): fix test_read_latest_high_frequency expectations * Apply suggestions from code review 1 Co-authored-by: Caroline Pascal Signed-off-by: Steven Palma * 
chore(cameras): address feedback * feat(cameras): add max_age_ms check in read_latest * test(cameras): fix read_latest tests * chore(redundancies): removing redundancies in Reachy 2 camera class * fix(warmup): replacing the arbitrary time.sleep in by an actual warmup in the RealSense camera class * chore(format): formatting latest changes * chore(warning): adding a "to be implemented" warning for read_latest() in Camera base class * chore(warning): making read_latest() warning message shorter and clearer --------- Signed-off-by: Steven Palma Co-authored-by: Caroline Pascal --- examples/backward_compatibility/replay.py | 31 +-- examples/lekiwi/evaluate.py | 88 +++---- examples/lekiwi/record.py | 89 +++---- examples/lekiwi/replay.py | 32 +-- examples/phone_to_so100/evaluate.py | 85 ++++--- examples/phone_to_so100/record.py | 85 ++++--- examples/phone_to_so100/replay.py | 42 ++-- examples/so100_to_so100_EE/evaluate.py | 85 ++++--- examples/so100_to_so100_EE/record.py | 86 ++++--- examples/so100_to_so100_EE/replay.py | 41 +-- src/lerobot/cameras/camera.py | 100 ++++++-- src/lerobot/cameras/opencv/camera_opencv.py | 166 ++++++++---- .../cameras/reachy2_camera/reachy2_camera.py | 67 +++-- .../cameras/realsense/camera_realsense.py | 212 +++++++++++----- src/lerobot/cameras/zmq/camera_zmq.py | 238 ++++++++++++++---- src/lerobot/cameras/zmq/configuration_zmq.py | 1 + src/lerobot/scripts/lerobot_calibrate.py | 7 +- src/lerobot/scripts/lerobot_replay.py | 29 +-- tests/cameras/test_opencv.py | 184 +++++++++----- tests/cameras/test_reachy2_camera.py | 38 +++ tests/cameras/test_realsense.py | 114 +++++---- 21 files changed, 1179 insertions(+), 641 deletions(-) diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index ed78d016f..8de5ba197 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -81,24 +81,25 @@ def replay(cfg: ReplayConfig): actions = 
dataset.hf_dataset.select_columns(ACTION) robot.connect() - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() + try: + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() - action_array = actions[idx][ACTION] - action = {} - for i, name in enumerate(dataset.features[ACTION]["names"]): - key = f"{name.removeprefix('main_')}.pos" - action[key] = action_array[i].item() + action_array = actions[idx][ACTION] + action = {} + for i, name in enumerate(dataset.features[ACTION]["names"]): + key = f"{name.removeprefix('main_')}.pos" + action[key] = action_array[i].item() - action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) - action["elbow_flex.pos"] -= 90 - robot.send_action(action) + action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) + action["elbow_flex.pos"] -= 90 + robot.send_action(action) - dt_s = time.perf_counter() - start_episode_t - precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) - - robot.disconnect() + dt_s = time.perf_counter() - start_episode_t + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) + finally: + robot.disconnect() if __name__ == "__main__": diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py index 2f7f9f95f..a3144a442 100644 --- a/examples/lekiwi/evaluate.py +++ b/examples/lekiwi/evaluate.py @@ -78,40 +78,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="lekiwi_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - recorded_episodes = 0 - while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") + print("Starting evaluate loop...") + 
recorded_episodes = 0 + while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ( - (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -120,24 +104,42 @@ def main(): robot_observation_processor=robot_observation_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) - # Save episode - dataset.save_episode() - 
recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + recorded_episodes += 1 - dataset.finalize() - dataset.push_to_hub() + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() + + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index 18b9f857e..9292157f7 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -74,40 +74,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="lekiwi_record") - if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop...") - recorded_episodes = 0 - while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {recorded_episodes}") + print("Starting record loop...") + recorded_episodes = 0 + while recorded_episodes < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {recorded_episodes}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - dataset=dataset, - teleop=[leader_arm, keyboard], - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ( - 
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + dataset=dataset, teleop=[leader_arm, keyboard], - control_time_s=RESET_TIME_SEC, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=teleop_action_processor, @@ -115,26 +98,44 @@ def main(): robot_observation_processor=robot_observation_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=[leader_arm, keyboard], + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ) - # Save episode - dataset.save_episode() - recorded_episodes += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - leader_arm.disconnect() - keyboard.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + recorded_episodes += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + leader_arm.disconnect() + keyboard.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py index 
872dacf27..cf89aea16 100644 --- a/examples/lekiwi/replay.py +++ b/examples/lekiwi/replay.py @@ -42,25 +42,27 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Send action to robot - _ = robot.send_action(action) + # Send action to robot + _ = robot.send_action(action) - precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - - robot.disconnect() + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + finally: + robot.disconnect() if __name__ == "__main__": diff --git a/examples/phone_to_so100/evaluate.py b/examples/phone_to_so100/evaluate.py index 246c923aa..837217eda 100644 --- a/examples/phone_to_so100/evaluate.py +++ b/examples/phone_to_so100/evaluate.py @@ -142,38 +142,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="phone_so100_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - episode_idx = 0 - for episode_idx in range(NUM_EPISODES): - log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting evaluate loop...") + episode_idx = 
0 + for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -182,24 +168,41 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + 
if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index 7b5b704e2..1f5005db9 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -149,38 +149,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="phone_so100_record") - if not robot.is_connected or not phone.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not robot.is_connected or not phone.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop. Move your phone to teleoperate the robot...") - episode_idx = 0 - while episode_idx < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting record loop. 
Move your phone to teleoperate the robot...") + episode_idx = 0 + while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - teleop=phone, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=phone_to_robot_ee_pose_processor, - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, teleop=phone, - control_time_s=RESET_TIME_SEC, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=phone_to_robot_ee_pose_processor, @@ -188,25 +173,43 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose, ) - if events["rerecord_episode"]: - log_say("Re-recording episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=phone, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-recording episode") + 
events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - phone.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + phone.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/phone_to_so100/replay.py b/examples/phone_to_so100/replay.py index 875025dfc..9d7806cf4 100644 --- a/examples/phone_to_so100/replay.py +++ b/examples/phone_to_so100/replay.py @@ -73,32 +73,34 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - ee_action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + ee_action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Get robot observation - robot_obs = robot.get_observation() + # Get robot observation + robot_obs = robot.get_observation() - # Dataset EE -> robot joints - joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) - # Send action to robot - _ = robot.send_action(joint_action) + # Send action to robot + _ = robot.send_action(joint_action) - 
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - - # Clean up - robot.disconnect() + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + finally: + # Clean up + robot.disconnect() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/evaluate.py b/examples/so100_to_so100_EE/evaluate.py index 87d188f99..b614b89f2 100644 --- a/examples/so100_to_so100_EE/evaluate.py +++ b/examples/so100_to_so100_EE/evaluate.py @@ -142,38 +142,24 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="so100_so100_evaluate") - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting evaluate loop...") - episode_idx = 0 - for episode_idx in range(NUM_EPISODES): - log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting evaluate loop...") + episode_idx = 0 + for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=robot, - events=events, - fps=FPS, - policy=policy, - preprocessor=preprocessor, # Pass the pre and post policy processors - postprocessor=postprocessor, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=make_default_teleop_action_processor(), - robot_action_processor=robot_ee_to_joints_processor, - robot_observation_processor=robot_joints_to_ee_pose_processor, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=robot, events=events, fps=FPS, + policy=policy, + preprocessor=preprocessor, # Pass the pre and post policy processors + 
postprocessor=postprocessor, + dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, @@ -182,24 +168,41 @@ def main(): robot_observation_processor=robot_joints_to_ee_pose_processor, ) - if events["rerecord_episode"]: - log_say("Re-record episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + (episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=make_default_teleop_action_processor(), + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-record episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - robot.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 + finally: + # Clean up + log_say("Stop recording") + robot.disconnect() + listener.stop() - dataset.finalize() - dataset.push_to_hub() + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index eead7a9a8..d85a1c5cc 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -146,38 +146,23 @@ def main(): listener, events = init_keyboard_listener() init_rerun(session_name="recording_phone") - if not leader.is_connected or not follower.is_connected: - raise ValueError("Robot or teleop is not connected!") + try: + if not 
leader.is_connected or not follower.is_connected: + raise ValueError("Robot or teleop is not connected!") - print("Starting record loop...") - episode_idx = 0 - while episode_idx < NUM_EPISODES and not events["stop_recording"]: - log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + print("Starting record loop...") + episode_idx = 0 + while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") - # Main record loop - record_loop( - robot=follower, - events=events, - fps=FPS, - teleop=leader, - dataset=dataset, - control_time_s=EPISODE_TIME_SEC, - single_task=TASK_DESCRIPTION, - display_data=True, - teleop_action_processor=leader_joints_to_ee, - robot_action_processor=ee_to_follower_joints, - robot_observation_processor=follower_joints_to_ee, - ) - - # Reset the environment if not stopping or re-recording - if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): - log_say("Reset the environment") + # Main record loop record_loop( robot=follower, events=events, fps=FPS, teleop=leader, - control_time_s=RESET_TIME_SEC, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, display_data=True, teleop_action_processor=leader_joints_to_ee, @@ -185,25 +170,44 @@ def main(): robot_observation_processor=follower_joints_to_ee, ) - if events["rerecord_episode"]: - log_say("Re-recording episode") - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and ( + episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"] + ): + log_say("Reset the environment") + record_loop( + robot=follower, + events=events, + fps=FPS, + teleop=leader, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + teleop_action_processor=leader_joints_to_ee, + 
robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, + ) - # Save episode - dataset.save_episode() - episode_idx += 1 + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue - # Clean up - log_say("Stop recording") - leader.disconnect() - follower.disconnect() - listener.stop() + # Save episode + dataset.save_episode() + episode_idx += 1 - dataset.finalize() - dataset.push_to_hub() + finally: + # Clean up + log_say("Stop recording") + leader.disconnect() + follower.disconnect() + listener.stop() + + dataset.finalize() + dataset.push_to_hub() if __name__ == "__main__": diff --git a/examples/so100_to_so100_EE/replay.py b/examples/so100_to_so100_EE/replay.py index 7d35a7b44..47a2f6635 100644 --- a/examples/so100_to_so100_EE/replay.py +++ b/examples/so100_to_so100_EE/replay.py @@ -74,32 +74,35 @@ def main(): # Connect to the robot robot.connect() - if not robot.is_connected: - raise ValueError("Robot is not connected!") + try: + if not robot.is_connected: + raise ValueError("Robot is not connected!") - print("Starting replay loop...") - log_say(f"Replaying episode {EPISODE_IDX}") - for idx in range(len(episode_frames)): - t0 = time.perf_counter() + print("Starting replay loop...") + log_say(f"Replaying episode {EPISODE_IDX}") + for idx in range(len(episode_frames)): + t0 = time.perf_counter() - # Get recorded action from dataset - ee_action = { - name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"]) - } + # Get recorded action from dataset + ee_action = { + name: float(actions[idx][ACTION][i]) + for i, name in enumerate(dataset.features[ACTION]["names"]) + } - # Get robot observation - robot_obs = robot.get_observation() + # Get robot observation + robot_obs = robot.get_observation() - # Dataset EE -> robot joints - joint_action = 
robot_ee_to_joints_processor((ee_action, robot_obs)) + # Dataset EE -> robot joints + joint_action = robot_ee_to_joints_processor((ee_action, robot_obs)) - # Send action to robot - _ = robot.send_action(joint_action) + # Send action to robot + _ = robot.send_action(joint_action) - precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) - # Clean up - robot.disconnect() + finally: + # Clean up + robot.disconnect() if __name__ == "__main__": diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py index bfdb571a7..2894e0215 100644 --- a/src/lerobot/cameras/camera.py +++ b/src/lerobot/cameras/camera.py @@ -15,11 +15,12 @@ # limitations under the License. import abc +import warnings from typing import Any from numpy.typing import NDArray # type: ignore # TODO: add type stubs for numpy.typing -from .configs import CameraConfig, ColorMode +from .configs import CameraConfig class Camera(abc.ABC): @@ -30,20 +31,12 @@ class Camera(abc.ABC): Manages basic camera properties (FPS, resolution) and core operations: - Connection/disconnection - - Frame capture (sync/async) + - Frame capture (sync/async/latest) Attributes: fps (int | None): Configured frames per second width (int | None): Frame width in pixels height (int | None): Frame height in pixels - - Example: - class MyCamera(Camera): - def __init__(self, config): ... - @property - def is_connected(self) -> bool: ... - def connect(self, warmup=True): ... - # Plus other required methods """ def __init__(self, config: CameraConfig): @@ -56,6 +49,32 @@ class Camera(abc.ABC): self.width: int | None = config.width self.height: int | None = config.height + def __enter__(self): + """ + Context manager entry. + Automatically connects to the camera. + """ + self.connect() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + """ + Context manager exit. 
+ Automatically disconnects, ensuring resources are released even on error. + """ + self.disconnect() + + def __del__(self) -> None: + """ + Destructor safety net. + Attempts to disconnect if the object is garbage collected without cleanup. + """ + try: + if self.is_connected: + self.disconnect() + except Exception: # nosec B110 + pass + @property @abc.abstractmethod def is_connected(self) -> bool: @@ -89,12 +108,10 @@ class Camera(abc.ABC): pass @abc.abstractmethod - def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: - """Capture and return a single frame from the camera. + def read(self) -> NDArray[Any]: + """Capture and return a single frame from the camera synchronously. - Args: - color_mode: Desired color mode for the output frame. If None, - uses the camera's default color mode. + This is a blocking call that will wait for the hardware and its SDK. Returns: np.ndarray: Captured frame as a numpy array. @@ -103,17 +120,64 @@ class Camera(abc.ABC): @abc.abstractmethod def async_read(self, timeout_ms: float = ...) -> NDArray[Any]: - """Asynchronously capture and return a single frame from the camera. + """Return the most recent new frame. + + This method retrieves the latest frame captured by the background thread. + If a new frame is already available in the buffer (captured since the last call), + it returns it immediately. + + It blocks up to `timeout_ms` only if the buffer is empty or if the latest frame + was already consumed by a previous `async_read` call. + + Essentially, this method returns the latest unconsumed frame, waiting if necessary + for a new one to arrive within the specified timeout. + + Usage: + - Ideal for control loops where you want to ensure every processed frame + is fresh, effectively synchronizing your loop to the camera's FPS. + - Causes of a timeout usually include: very low camera FPS, heavy processing load, + or if the camera is disconnected. Args: - timeout_ms: Maximum time to wait for a frame in milliseconds. 
- Defaults to implementation-specific timeout. + timeout_ms: Maximum time to wait for a new frame in milliseconds. + Defaults to 200ms (0.2s). Returns: np.ndarray: Captured frame as a numpy array. + + Raises: + TimeoutError: If no new frame arrives within `timeout_ms`. """ pass + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Usage: + Ideal for scenarios requiring zero latency or decoupled frequencies & when + we want a guaranteed frame, such as UI visualization, logging, or + non-critical monitoring. + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + NotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + warnings.warn( + f"{self.__class__.__name__}.read_latest() is not implemented. 
" + "Please override read_latest(); it will be required in future releases.", + FutureWarning, + stacklevel=2, + ) + return self.async_read() + @abc.abstractmethod def disconnect(self) -> None: """Disconnect from the camera and release resources.""" diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index b1043ba64..d581e1425 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -70,34 +70,24 @@ class OpenCVCamera(Camera): Example: ```python from lerobot.cameras.opencv import OpenCVCamera - from lerobot.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation + from lerobot.cameras.configuration_opencv import OpenCVCameraConfig # Basic usage with camera index 0 config = OpenCVCameraConfig(index_or_path=0) camera = OpenCVCamera(config) camera.connect() - # Read 1 frame synchronously + # Read 1 frame synchronously (blocking) color_image = camera.read() - print(color_image.shape) - # Read 1 frame asynchronously + # Read 1 frame asynchronously (waits for new frame with a timeout) async_image = camera.async_read() + # Get the latest frame immediately (no wait; raises if the frame is too old) + latest_image = camera.read_latest() + # When done, properly disconnect the camera using camera.disconnect() - - # Example with custom settings - custom_config = OpenCVCameraConfig( - index_or_path='/dev/video0', # Or use an index - fps=30, - width=1280, - height=720, - color_mode=ColorMode.RGB, - rotation=Cv2Rotation.ROTATE_90 - ) - custom_camera = OpenCVCamera(custom_config) - # ... connect, read, disconnect ... 
``` """ @@ -123,6 +113,7 @@ class OpenCVCamera(Camera): self.stop_event: Event | None = None self.frame_lock: Lock = Lock() self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) @@ -146,12 +137,16 @@ class OpenCVCamera(Camera): Connects to the OpenCV camera specified in the configuration. Initializes the OpenCV VideoCapture object, sets desired camera properties - (FPS, width, height), and performs initial checks. + (FPS, width, height), starts the background reading thread and performs initial checks. + + Args: + warmup (bool): If True, waits at connect() time until at least one valid frame + has been captured by the background thread. Defaults to True. Raises: DeviceAlreadyConnectedError: If the camera is already connected. - ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open. - RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings. + ConnectionError: If the specified camera index/path is not found or fails to open. + RuntimeError: If the camera opens but fails to apply requested settings. """ if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} is already connected.") @@ -170,12 +165,16 @@ class OpenCVCamera(Camera): ) self._configure_capture_settings() + self._start_read_thread() - if warmup: + if warmup and self.warmup_s > 0: start_time = time.time() while time.time() - start_time < self.warmup_s: - self.read() + self.async_read(timeout_ms=self.warmup_s * 1000) time.sleep(0.1) + with self.frame_lock: + if self.latest_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") logger.info(f"{self} connected.") @@ -196,8 +195,7 @@ class OpenCVCamera(Camera): Raises: RuntimeError: If the camera fails to set any of the specified properties to the requested value. 
- DeviceNotConnectedError: If the camera is not connected when attempting - to configure settings. + DeviceNotConnectedError: If the camera is not connected. """ if not self.is_connected: raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") @@ -339,6 +337,17 @@ class OpenCVCamera(Camera): return found_cameras_info + def _read_from_hardware(self) -> NDArray[Any]: + if self.videocapture is None: + raise DeviceNotConnectedError(f"{self} videocapture is not initialized") + + ret, frame = self.videocapture.read() + + if not ret: + raise RuntimeError(f"{self} read failed (status={ret}).") + + return frame + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -346,11 +355,6 @@ class OpenCVCamera(Camera): This is a blocking call. It waits for the next available frame from the camera hardware via OpenCV. - Args: - color_mode (Optional[ColorMode]): If specified, overrides the default - color mode (`self.color_mode`) for this read operation (e.g., - request RGB even if default is BGR). - Returns: np.ndarray: The captured frame as a NumPy array in the format (height, width, channels), using the specified or default @@ -362,34 +366,34 @@ class OpenCVCamera(Camera): received frame dimensions don't match expectations before rotation. ValueError: If an invalid `color_mode` is requested. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") start_time = time.perf_counter() - if self.videocapture is None: - raise DeviceNotConnectedError(f"{self} videocapture is not initialized") + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
+ ) - ret, frame = self.videocapture.read() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") - if not ret or frame is None: - raise RuntimeError(f"{self} read failed (status={ret}).") + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") - processed_frame = self._postprocess_image(frame, color_mode) + self.new_frame_event.clear() + frame = self.async_read(timeout_ms=10000) read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return processed_frame + return frame - def _postprocess_image(self, image: NDArray[Any], color_mode: ColorMode | None = None) -> NDArray[Any]: + def _postprocess_image(self, image: NDArray[Any]) -> NDArray[Any]: """ Applies color conversion, dimension validation, and rotation to a raw frame. Args: image (np.ndarray): The raw image frame (expected BGR format from OpenCV). - color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, - uses the instance's default `self.color_mode`. Returns: np.ndarray: The processed image frame. @@ -399,11 +403,10 @@ class OpenCVCamera(Camera): RuntimeError: If the raw frame dimensions do not match the configured `width` and `height`. """ - requested_color_mode = self.color_mode if color_mode is None else color_mode - if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR): + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." 
) h, w, c = image.shape @@ -417,7 +420,7 @@ class OpenCVCamera(Camera): raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).") processed_image = image - if requested_color_mode == ColorMode.RGB: + if self.color_mode == ColorMode.RGB: processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: @@ -431,7 +434,7 @@ class OpenCVCamera(Camera): On each iteration: 1. Reads a color frame - 2. Stores result in latest_frame (thread-safe) + 2. Stores result in latest_frame and updates timestamp (thread-safe) 3. Sets new_frame_event to notify listeners Stops on DeviceNotConnectedError, logs other errors and continues. @@ -439,30 +442,37 @@ class OpenCVCamera(Camera): if self.stop_event is None: raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") + failure_count = 0 while not self.stop_event.is_set(): try: - color_image = self.read() + raw_frame = self._read_from_hardware() + processed_frame = self._postprocess_image(raw_frame) + capture_time = time.perf_counter() with self.frame_lock: - self.latest_frame = color_image + self.latest_frame = processed_frame + self.latest_timestamp = capture_time self.new_frame_event.set() + failure_count = 0 except DeviceNotConnectedError: break except Exception as e: - logger.warning(f"Error reading frame in background thread for {self}: {e}") + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Error reading frame in background thread for {self}: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e def _start_read_thread(self) -> None: """Starts or restarts the background read thread if it's not running.""" - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=0.1) - if self.stop_event is not None: - self.stop_event.set() + self._stop_read_thread() self.stop_event = Event() self.thread = 
Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") self.thread.daemon = True self.thread.start() + time.sleep(0.1) def _stop_read_thread(self) -> None: """Signals the background read thread to stop and waits for it to join.""" @@ -475,6 +485,11 @@ class OpenCVCamera(Camera): self.thread = None self.stop_event = None + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -482,6 +497,7 @@ class OpenCVCamera(Camera): This method retrieves the most recent frame captured by the background read thread. It does not block waiting for the camera hardware directly, but may wait up to timeout_ms for the background thread to provide a frame. + It is “best effort” under high FPS. Args: timeout_ms (float): Maximum time in milliseconds to wait for a frame @@ -500,13 +516,12 @@ class OpenCVCamera(Camera): raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): - self._start_read_thread() + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): - thread_alive = self.thread is not None and self.thread.is_alive() raise TimeoutError( f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " - f"Read thread alive: {thread_alive}." + f"Read thread alive: {self.thread.is_alive()}." ) with self.frame_lock: @@ -518,6 +533,42 @@ class OpenCVCamera(Camera): return frame + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). 
+ + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """ Disconnects from the camera and cleans up resources. @@ -538,4 +589,9 @@ class OpenCVCamera(Camera): self.videocapture.release() self.videocapture = None + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index c8916c5ee..5cede466d 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -80,6 +80,8 @@ class Reachy2Camera(Camera): self.config = config self.color_mode = config.color_mode + self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.cam_manager: CameraManager | None = None @@ -125,12 +127,7 @@ class Reachy2Camera(Camera): """ Reads a single frame synchronously from the camera. - This is a blocking call. 
- - Args: - color_mode (Optional[ColorMode]): If specified, overrides the default - color mode (`self.color_mode`) for this read operation (e.g., - request RGB even if default is BGR). + This method retrieves the most recent frame available in Reachy 2's low-level software. Returns: np.ndarray: The captured frame as a NumPy array in the format @@ -145,6 +142,11 @@ class Reachy2Camera(Camera): if self.cam_manager is None: raise DeviceNotConnectedError(f"{self} is not connected.") + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) + frame: NDArray[Any] = np.empty((0, 0, 3), dtype=np.uint8) if self.config.name == "teleop" and hasattr(self.cam_manager, "teleop"): @@ -165,11 +167,18 @@ class Reachy2Camera(Camera): raise ValueError(f"Invalid camera name '{self.config.name}'. Expected 'teleop' or 'depth'.") if frame is None: - return np.empty((0, 0, 3), dtype=np.uint8) + raise RuntimeError(f"Internal error: No frame available for {self}.") - if self.config.color_mode == "rgb": + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + ) + if self.color_mode == ColorMode.RGB: frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + self.latest_frame = frame + self.latest_timestamp = time.perf_counter() + read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") @@ -177,13 +186,7 @@ class Reachy2Camera(Camera): def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ - Reads the latest available frame. - - This method retrieves the most recent frame available in Reachy 2's low-level software. - - Args: - timeout_ms (float): Maximum time in milliseconds to wait for a frame - to become available. Defaults to 200ms (0.2 seconds). 
+ Same as read() Returns: np.ndarray: The latest captured frame as a NumPy array in the format @@ -197,12 +200,38 @@ class Reachy2Camera(Camera): if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - frame = self.read() + return self.read() - if frame is None: - raise RuntimeError(f"Internal error: No frame available for {self}.") + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). - return frame + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: + The frame image (numpy array) stored by the most recent `read()` call; + the capture timestamp is used only for the `max_age_ms` check and is not returned. + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.latest_frame is None or self.latest_timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - self.latest_timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." 
+ ) + + return self.latest_frame def disconnect(self) -> None: """ diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index f2906fdd8..e47f25381 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -72,15 +72,14 @@ class RealSenseCamera(Camera): camera = RealSenseCamera(config) camera.connect() - # Read 1 frame synchronously + # Read 1 frame synchronously (blocking) color_image = camera.read() - print(color_image.shape) - # Read 1 frame asynchronously + # Read 1 frame asynchronously (waits for new frame with a timeout) async_image = camera.async_read() - # When done, properly disconnect the camera using - camera.disconnect() + # Get the latest frame immediately (no wait) + latest_image = camera.read_latest() # Example with depth capture and custom settings custom_config = RealSenseCameraConfig( @@ -133,7 +132,9 @@ class RealSenseCamera(Camera): self.thread: Thread | None = None self.stop_event: Event | None = None self.frame_lock: Lock = Lock() - self.latest_frame: NDArray[Any] | None = None + self.latest_color_frame: NDArray[Any] | None = None + self.latest_depth_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) @@ -158,6 +159,10 @@ class RealSenseCamera(Camera): Initializes the RealSense pipeline, configures the required streams (color and optionally depth), starts the pipeline, and validates the actual stream settings. + Args: + warmup (bool): If True, waits at connect() time until at least one valid frame + has been captured by the background thread. Defaults to True. + Raises: DeviceAlreadyConnectedError: If the camera is already connected. ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique). 
@@ -181,15 +186,18 @@ class RealSenseCamera(Camera): ) from e self._configure_capture_settings() + self._start_read_thread() - if warmup: - time.sleep( - 1 - ) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise. - start_time = time.time() - while time.time() - start_time < self.warmup_s: - self.read() - time.sleep(0.1) + # NOTE(Steven/Caroline): Enforcing at least one second of warmup as RS cameras need a bit of time before the first read. If we don't wait, the first read from the warmup will raise. + self.warmup_s = max(self.warmup_s, 1) + + start_time = time.time() + while time.time() - start_time < self.warmup_s: + self.async_read(timeout_ms=self.warmup_s * 1000) + time.sleep(0.1) + with self.frame_lock: + if self.latest_color_frame is None or self.use_depth and self.latest_depth_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") logger.info(f"{self} connected.") @@ -319,9 +327,6 @@ class RealSenseCamera(Camera): This is a blocking call. It waits for a coherent set of frames (depth) from the camera hardware via the RealSense pipeline. - Args: - timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. - Returns: np.ndarray: The depth map as a NumPy array (height, width) of type `np.uint16` (raw depth values in millimeters) and rotation. @@ -330,44 +335,52 @@ class RealSenseCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If reading frames from the pipeline fails or frames are invalid. """ + if timeout_ms: + logger.warning( + f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." + ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if not self.use_depth: raise RuntimeError( f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." 
) - start_time = time.perf_counter() + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + self.new_frame_event.clear() + + _ = self.async_read(timeout_ms=10000) + + with self.frame_lock: + depth_map = self.latest_depth_frame + + if depth_map is None: + raise RuntimeError("No depth frame available. Ensure camera is streaming.") + + return depth_map + + def _read_from_hardware(self): if self.rs_pipeline is None: raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.") - ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) + ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=10000) if not ret or frame is None: - raise RuntimeError(f"{self} read_depth failed (status={ret}).") + raise RuntimeError(f"{self} read failed (status={ret}).") - depth_frame = frame.get_depth_frame() - depth_map = np.asanyarray(depth_frame.get_data()) + return frame - depth_map_processed = self._postprocess_image(depth_map, depth_frame=True) - - read_duration_ms = (time.perf_counter() - start_time) * 1e3 - logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - - return depth_map_processed - - def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> NDArray[Any]: + def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]: """ Reads a single frame (color) synchronously from the camera. This is a blocking call. It waits for a coherent set of frames (color) from the camera hardware via the RealSense pipeline. - Args: - timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. - Returns: np.ndarray: The captured color frame as a NumPy array (height, width, channels), processed according to `color_mode` and rotation. 
@@ -378,39 +391,39 @@ class RealSenseCamera(Camera): ValueError: If an invalid `color_mode` is requested. """ + start_time = time.perf_counter() + + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." + ) + + if timeout_ms: + logger.warning( + f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." + ) + if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - start_time = time.perf_counter() + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") - if self.rs_pipeline is None: - raise RuntimeError(f"{self}: rs_pipeline must be initialized before use.") + self.new_frame_event.clear() - ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) - - if not ret or frame is None: - raise RuntimeError(f"{self} read failed (status={ret}).") - - color_frame = frame.get_color_frame() - color_image_raw = np.asanyarray(color_frame.get_data()) - - color_image_processed = self._postprocess_image(color_image_raw, color_mode) + frame = self.async_read(timeout_ms=10000) read_duration_ms = (time.perf_counter() - start_time) * 1e3 logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") - return color_image_processed + return frame - def _postprocess_image( - self, image: NDArray[Any], color_mode: ColorMode | None = None, depth_frame: bool = False - ) -> NDArray[Any]: + def _postprocess_image(self, image: NDArray[Any], depth_frame: bool = False) -> NDArray[Any]: """ Applies color conversion, dimension validation, and rotation to a raw color frame. Args: image (np.ndarray): The raw image frame (expected RGB format from RealSense). - color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, - uses the instance's default `self.color_mode`. Returns: np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`. 
@@ -421,9 +434,9 @@ class RealSenseCamera(Camera): `width` and `height`. """ - if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR): + if self.color_mode and self.color_mode not in (ColorMode.RGB, ColorMode.BGR): raise ValueError( - f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + f"Invalid requested color mode '{self.color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." ) if depth_frame: @@ -454,7 +467,7 @@ class RealSenseCamera(Camera): On each iteration: 1. Reads a color frame with 500ms timeout - 2. Stores result in latest_frame (thread-safe) + 2. Stores result in latest_frame and updates timestamp (thread-safe) 3. Sets new_frame_event to notify listeners Stops on DeviceNotConnectedError, logs other errors and continues. @@ -462,25 +475,41 @@ class RealSenseCamera(Camera): if self.stop_event is None: raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") + failure_count = 0 while not self.stop_event.is_set(): try: - color_image = self.read(timeout_ms=500) + frame = self._read_from_hardware() + color_frame_raw = frame.get_color_frame() + color_frame = np.asanyarray(color_frame_raw.get_data()) + processed_color_frame = self._postprocess_image(color_frame) + + if self.use_depth: + depth_frame_raw = frame.get_depth_frame() + depth_frame = np.asanyarray(depth_frame_raw.get_data()) + processed_depth_frame = self._postprocess_image(depth_frame, depth_frame=True) + + capture_time = time.perf_counter() with self.frame_lock: - self.latest_frame = color_image + self.latest_color_frame = processed_color_frame + if self.use_depth: + self.latest_depth_frame = processed_depth_frame + self.latest_timestamp = capture_time self.new_frame_event.set() + failure_count = 0 except DeviceNotConnectedError: break except Exception as e: - logger.warning(f"Error reading frame in background thread for {self}: {e}") + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Error 
reading frame in background thread for {self}: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e def _start_read_thread(self) -> None: """Starts or restarts the background read thread if it's not running.""" - if self.thread is not None and self.thread.is_alive(): - self.thread.join(timeout=0.1) - if self.stop_event is not None: - self.stop_event.set() + self._stop_read_thread() self.stop_event = Event() self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") @@ -498,6 +527,12 @@ class RealSenseCamera(Camera): self.thread = None self.stop_event = None + with self.frame_lock: + self.latest_color_frame = None + self.latest_depth_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + # NOTE(Steven): Missing implementation for depth for now def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ @@ -506,6 +541,7 @@ class RealSenseCamera(Camera): This method retrieves the most recent color frame captured by the background read thread. It does not block waiting for the camera hardware directly, but may wait up to timeout_ms for the background thread to provide a frame. + It is “best effort” under high FPS. Args: timeout_ms (float): Maximum time in milliseconds to wait for a frame @@ -524,17 +560,16 @@ class RealSenseCamera(Camera): raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): - self._start_read_thread() + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): - thread_alive = self.thread is not None and self.thread.is_alive() raise TimeoutError( f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " - f"Read thread alive: {thread_alive}." + f"Read thread alive: {self.thread.is_alive()}." 
) with self.frame_lock: - frame = self.latest_frame + frame = self.latest_color_frame self.new_frame_event.clear() if frame is None: @@ -542,6 +577,43 @@ class RealSenseCamera(Camera): return frame + # NOTE(Steven): Missing implementation for depth for now + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent (color) frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_color_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """ Disconnects from the camera, stops the pipeline, and cleans up resources. 
@@ -565,4 +637,10 @@ class RealSenseCamera(Camera): self.rs_pipeline = None self.rs_profile = None + with self.frame_lock: + self.latest_color_frame = None + self.latest_depth_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index 1a4155f4b..a231a582a 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -45,6 +45,12 @@ logger = logging.getLogger(__name__) class ZMQCamera(Camera): """ + Manages camera interactions via ZeroMQ for receiving frames from a remote server. + + This class connects to a ZMQ Publisher, subscribes to frame topics, and decodes + incoming JSON messages containing Base64 encoded images. It supports both + synchronous and asynchronous frame reading patterns. + Example usage: ```python from lerobot.cameras.zmq import ZMQCamera, ZMQCameraConfig @@ -52,7 +58,16 @@ class ZMQCamera(Camera): config = ZMQCameraConfig(server_address="192.168.123.164", port=5555, camera_name="head_camera") camera = ZMQCamera(config) camera.connect() - frame = camera.read() + + # Read 1 frame synchronously (blocking) + color_image = camera.read() + + # Read 1 frame asynchronously (waits for new frame with a timeout) + async_image = camera.async_read() + + # Get the latest frame immediately (no wait, returns timestamp) + latest_image, timestamp = camera.read_latest() + camera.disconnect() ``` """ @@ -68,14 +83,17 @@ class ZMQCamera(Camera): self.color_mode = config.color_mode self.timeout_ms = config.timeout_ms + # ZMQ Context and Socket self.context: zmq.Context | None = None self.socket: zmq.Socket | None = None self._connected = False + # Threading resources self.thread: Thread | None = None self.stop_event: Event | None = None self.frame_lock: Lock = Lock() self.latest_frame: NDArray[Any] | None = None + self.latest_timestamp: float | None = None self.new_frame_event: 
Event = Event() def __str__(self) -> str: @@ -83,10 +101,16 @@ class ZMQCamera(Camera): @property def is_connected(self) -> bool: + """Checks if the ZMQ socket is initialized and connected.""" return self._connected and self.context is not None and self.socket is not None def connect(self, warmup: bool = True) -> None: - """Connect to ZMQ camera server.""" + """Connect to ZMQ camera server. + + Args: + warmup (bool): If True, waits for the camera to provide at least one + valid frame before returning. Defaults to True. + """ if self.is_connected: raise DeviceAlreadyConnectedError(f"{self} is already connected.") @@ -103,17 +127,28 @@ class ZMQCamera(Camera): self.socket.connect(f"tcp://{self.server_address}:{self.port}") self._connected = True - # Auto-detect resolution + # Auto-detect resolution if not provided if self.width is None or self.height is None: - h, w = self.read().shape[:2] + # Read directly from hardware because the thread isn't running yet + temp_frame = self._read_from_hardware() + h, w = temp_frame.shape[:2] self.height = h self.width = w - logger.info(f"{self} resolution: {w}x{h}") + logger.info(f"{self} resolution detected: {w}x{h}") + self._start_read_thread() logger.info(f"{self} connected.") if warmup: - time.sleep(0.1) + # Ensure we have captured at least one frame via the thread + start_time = time.time() + while time.time() - start_time < (self.config.warmup_s): # Wait a bit more than timeout + self.async_read(timeout_ms=self.config.warmup_s * 1000) + time.sleep(0.1) + + with self.frame_lock: + if self.latest_frame is None: + raise ConnectionError(f"{self} failed to capture frames during warmup.") except Exception as e: self._cleanup() @@ -134,12 +169,9 @@ class ZMQCamera(Camera): """ZMQ cameras require manual configuration (server address/port).""" return [] - def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + def _read_from_hardware(self) -> NDArray[Any]: """ - Read a single frame from the ZMQ camera. 
- - Returns: - np.ndarray: Decoded frame (height, width, 3) + Reads a single frame directly from the ZMQ socket. """ if not self.is_connected or self.socket is None: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -147,6 +179,7 @@ class ZMQCamera(Camera): try: message = self.socket.recv_string() except Exception as e: + # Check for ZMQ timeout (EAGAIN/Again) without requiring global zmq import if type(e).__name__ == "Again": raise TimeoutError(f"{self} timeout after {self.timeout_ms}ms") from e raise @@ -176,42 +209,117 @@ class ZMQCamera(Camera): return frame - def _read_loop(self) -> None: - while self.stop_event and not self.stop_event.is_set(): - try: - frame = self.read() - with self.frame_lock: - self.latest_frame = frame - self.new_frame_event.set() - except DeviceNotConnectedError: - break - except TimeoutError: - pass - except Exception as e: - logger.warning(f"Read error: {e}") + def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: + """ + Reads a single frame synchronously from the camera. - def _start_read_thread(self) -> None: - if self.thread and self.thread.is_alive(): - return - self.stop_event = Event() - self.thread = Thread(target=self._read_loop, daemon=True) - self.thread.start() + This is a blocking call. It waits for the next available frame from the + camera background thread. - def _stop_read_thread(self) -> None: - if self.stop_event: - self.stop_event.set() - if self.thread and self.thread.is_alive(): - self.thread.join(timeout=2.0) - self.thread = None - self.stop_event = None + Returns: + np.ndarray: Decoded frame (height, width, 3) + """ + start_time = time.perf_counter() + + if color_mode is not None: + logger.warning( + f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
+ ) - def async_read(self, timeout_ms: float = 10000) -> NDArray[Any]: - """Read latest frame asynchronously (non-blocking).""" if not self.is_connected: raise DeviceNotConnectedError(f"{self} is not connected.") - if not self.thread or not self.thread.is_alive(): - self._start_read_thread() + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + self.new_frame_event.clear() + frame = self.async_read(timeout_ms=10000) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return frame + + def _read_loop(self) -> None: + """ + Internal loop run by the background thread for asynchronous reading. + """ + if self.stop_event is None: + raise RuntimeError(f"{self}: stop_event is not initialized.") + + failure_count = 0 + while not self.stop_event.is_set(): + try: + frame = self._read_from_hardware() + capture_time = time.perf_counter() + + with self.frame_lock: + self.latest_frame = frame + self.latest_timestamp = capture_time + self.new_frame_event.set() + failure_count = 0 + + except DeviceNotConnectedError: + break + except (TimeoutError, Exception) as e: + if failure_count <= 10: + failure_count += 1 + logger.warning(f"Read error: {e}") + else: + raise RuntimeError(f"{self} exceeded maximum consecutive read failures.") from e + + def _start_read_thread(self) -> None: + if self.stop_event is not None: + self.stop_event.set() + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, daemon=True, name=f"{self}_read_loop") + self.thread.start() + time.sleep(0.1) + + def _stop_read_thread(self) -> None: + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and 
self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + + def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: + """ + Reads the latest available frame asynchronously. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms. + + Returns: + np.ndarray: The latest captured frame. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame data becomes available within the specified timeout. + RuntimeError: If the background thread is not running. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): raise TimeoutError(f"{self} async_read timeout after {timeout_ms}ms") @@ -225,11 +333,55 @@ class ZMQCamera(Camera): return frame + def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + """Return the most recent frame captured immediately (Peeking). + + This method is non-blocking and returns whatever is currently in the + memory buffer. The frame may be stale, + meaning it could have been captured a while ago (hanging camera scenario e.g.). + + Returns: + NDArray[Any]: The frame image (numpy array). + + Raises: + TimeoutError: If the latest frame is older than `max_age_ms`. + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If the camera is connected but has not captured any frames yet. 
+ """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + raise RuntimeError(f"{self} read thread is not running.") + + with self.frame_lock: + frame = self.latest_frame + timestamp = self.latest_timestamp + + if frame is None or timestamp is None: + raise RuntimeError(f"{self} has not captured any frames yet.") + + age_ms = (time.perf_counter() - timestamp) * 1e3 + if age_ms > max_age_ms: + raise TimeoutError( + f"{self} latest frame is too old: {age_ms:.1f} ms (max allowed: {max_age_ms} ms)." + ) + + return frame + def disconnect(self) -> None: """Disconnect from ZMQ camera.""" - if not self.is_connected and not self.thread: + if not self.is_connected and self.thread is None: raise DeviceNotConnectedError(f"{self} not connected.") - self._stop_read_thread() + if self.thread is not None: + self._stop_read_thread() + self._cleanup() + + with self.frame_lock: + self.latest_frame = None + self.latest_timestamp = None + self.new_frame_event.clear() + logger.info(f"{self} disconnected.") diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py index 027ae12b5..4e7732cfc 100644 --- a/src/lerobot/cameras/zmq/configuration_zmq.py +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -29,6 +29,7 @@ class ZMQCameraConfig(CameraConfig): camera_name: str = "zmq_camera" color_mode: ColorMode = ColorMode.RGB timeout_ms: int = 5000 + warmup_s: int = 1 def __post_init__(self) -> None: if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): diff --git a/src/lerobot/scripts/lerobot_calibrate.py b/src/lerobot/scripts/lerobot_calibrate.py index eb3df6872..1b30021dd 100644 --- a/src/lerobot/scripts/lerobot_calibrate.py +++ b/src/lerobot/scripts/lerobot_calibrate.py @@ -86,8 +86,11 @@ def calibrate(cfg: CalibrateConfig): device = make_teleoperator_from_config(cfg.device) device.connect(calibrate=False) - device.calibrate() - 
device.disconnect() + + try: + device.calibrate() + finally: + device.disconnect() def main(): diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index 5717dffb6..c9a559d07 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ b/src/lerobot/scripts/lerobot_replay.py @@ -110,25 +110,26 @@ def replay(cfg: ReplayConfig): robot.connect() - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(len(episode_frames)): - start_episode_t = time.perf_counter() + try: + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(len(episode_frames)): + start_episode_t = time.perf_counter() - action_array = actions[idx][ACTION] - action = {} - for i, name in enumerate(dataset.features[ACTION]["names"]): - action[name] = action_array[i] + action_array = actions[idx][ACTION] + action = {} + for i, name in enumerate(dataset.features[ACTION]["names"]): + action[name] = action_array[i] - robot_obs = robot.get_observation() + robot_obs = robot.get_observation() - processed_action = robot_action_processor((action, robot_obs)) + processed_action = robot_action_processor((action, robot_obs)) - _ = robot.send_action(processed_action) + _ = robot.send_action(processed_action) - dt_s = time.perf_counter() - start_episode_t - precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) - - robot.disconnect() + dt_s = time.perf_counter() - start_episode_t + precise_sleep(max(1 / dataset.fps - dt_s, 0.0)) + finally: + robot.disconnect() def main(): diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py index 3cf3793b6..feb700631 100644 --- a/tests/cameras/test_opencv.py +++ b/tests/cameras/test_opencv.py @@ -20,7 +20,9 @@ # ``` from pathlib import Path +from unittest.mock import patch +import cv2 import numpy as np import pytest @@ -28,6 +30,50 @@ from lerobot.cameras.configs import Cv2Rotation from lerobot.cameras.opencv import OpenCVCamera, OpenCVCameraConfig from lerobot.utils.errors import 
DeviceAlreadyConnectedError, DeviceNotConnectedError +RealVideoCapture = cv2.VideoCapture + + +class MockLoopingVideoCapture: + """ + Wraps the real OpenCV VideoCapture. + Motivation: cv2.VideoCapture(file.png) is only valid for one read. + Strategy: Read the file once & return the cached frame for subsequent reads. + Consequence: No recurrent I/O operations, but we keep the test artifacts simple. + """ + + def __init__(self, *args, **kwargs): + args_clean = [str(a) if isinstance(a, Path) else a for a in args] + self._real_vc = RealVideoCapture(*args_clean, **kwargs) + self._cached_frame = None + + def read(self): + ret, frame = self._real_vc.read() + + if ret: + self._cached_frame = frame + return ret, frame + + if not ret and self._cached_frame is not None: + return True, self._cached_frame.copy() + + return ret, frame + + def __getattr__(self, name): + return getattr(self._real_vc, name) + + +@pytest.fixture(autouse=True) +def patch_opencv_videocapture(): + """ + Automatically patches cv2.VideoCapture for all tests. + """ + module_path = OpenCVCamera.__module__ + target = f"{module_path}.cv2.VideoCapture" + + with patch(target, new=MockLoopingVideoCapture): + yield + + # NOTE(Steven): more tests + assertions? 
TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png" @@ -43,25 +89,22 @@ def test_abc_implementation(): def test_connect(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - camera.connect(warmup=False) - - assert camera.is_connected + with OpenCVCamera(config) as camera: + assert camera.is_connected def test_connect_already_connected(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - with pytest.raises(DeviceAlreadyConnectedError): - camera.connect(warmup=False) + with OpenCVCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError): + camera.connect() def test_connect_invalid_camera_path(): config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png") + camera = OpenCVCamera(config) with pytest.raises(ConnectionError): @@ -74,27 +117,25 @@ def test_invalid_width_connect(): width=99999, # Invalid width to trigger error height=480, ) - camera = OpenCVCamera(config) + camera = OpenCVCamera(config) with pytest.raises(RuntimeError): camera.connect(warmup=False) @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) def test_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0) - img = camera.read() - - assert isinstance(img, np.ndarray) + with OpenCVCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) def test_read_before_connect(): config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) + camera = OpenCVCamera(config) with 
pytest.raises(DeviceNotConnectedError): _ = camera.read() @@ -119,32 +160,22 @@ def test_disconnect_before_connect(): @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) def test_async_read(index_or_path): - config = OpenCVCameraConfig(index_or_path=index_or_path) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, warmup_s=0) - try: + with OpenCVCamera(config) as camera: img = camera.async_read() assert camera.thread is not None assert camera.thread.is_alive() assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends def test_async_read_timeout(): - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() + with OpenCVCamera(config) as camera, pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) # consumes any available frame by then + camera.async_read(timeout_ms=0) # request immediately another one def test_async_read_before_connect(): @@ -155,6 +186,50 @@ def test_async_read_before_connect(): _ = camera.async_read() +def test_read_latest(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # ensure at least one fresh frame is captured + frame = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == frame.shape + + +def test_read_latest_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + + camera = OpenCVCamera(config) + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def 
test_read_latest_high_frequency(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # prime to ensure frames are available + ref = camera.read() + + for _ in range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_too_old(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, warmup_s=0) + + with OpenCVCamera(config) as camera: + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): + _ = camera.read_latest(max_age_ms=0) # immediately too old + + def test_fourcc_configuration(): """Test FourCC configuration validation and application.""" @@ -181,18 +256,15 @@ def test_fourcc_configuration(): def test_fourcc_with_camera(): """Test FourCC functionality with actual camera connection.""" - config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG") - camera = OpenCVCamera(config) + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH, fourcc="MJPG", warmup_s=0) # Connect should work with MJPG specified - camera.connect(warmup=False) - assert camera.is_connected + with OpenCVCamera(config) as camera: + assert camera.is_connected - # Read should work normally - img = camera.read() - assert isinstance(img, np.ndarray) - - camera.disconnect() + # Read should work normally + img = camera.read() + assert isinstance(img, np.ndarray) @pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) @@ -211,18 +283,16 @@ def test_rotation(rotation, index_or_path): dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png) original_width, original_height = map(int, dimensions.split("x")) - config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation) - camera = OpenCVCamera(config) - camera.connect(warmup=False) + config = OpenCVCameraConfig(index_or_path=index_or_path, 
rotation=rotation, warmup_s=0) + with OpenCVCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == original_height - assert camera.height == original_width - assert img.shape[:2] == (original_width, original_height) - else: - assert camera.width == original_width - assert camera.height == original_height - assert img.shape[:2] == (original_height, original_width) + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == original_height + assert camera.height == original_width + assert img.shape[:2] == (original_width, original_height) + else: + assert camera.width == original_width + assert camera.height == original_height + assert img.shape[:2] == (original_height, original_width) diff --git a/tests/cameras/test_reachy2_camera.py b/tests/cameras/test_reachy2_camera.py index 14774bf38..2aebfdf0a 100644 --- a/tests/cameras/test_reachy2_camera.py +++ b/tests/cameras/test_reachy2_camera.py @@ -150,6 +150,44 @@ def test_async_read_before_connect(camera): _ = camera.async_read() +def test_read_latest(camera): + camera.connect() + + frame = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == frame.shape + + +def test_read_latest_before_connect(camera): + # camera fixture yields an unconnected camera instance + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def test_read_latest_high_frequency(camera): + camera.connect() + + # prime to ensure frames are available + ref = camera.read() + + for _ in range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_too_old(camera): + camera.connect() + + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): 
+ _ = camera.read_latest(max_age_ms=0) # immediately too old + + def test_wrong_camera_name(): with pytest.raises(ValueError): _ = Reachy2CameraConfig(name="wrong-name", image_type="left") diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py index fb9912257..1deb73f05 100644 --- a/tests/cameras/test_realsense.py +++ b/tests/cameras/test_realsense.py @@ -62,19 +62,15 @@ def test_abc_implementation(): def test_connect(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) + config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0) - camera.connect(warmup=False) - assert camera.is_connected + with RealSenseCamera(config) as camera: + assert camera.is_connected def test_connect_already_connected(): - config = RealSenseCameraConfig(serial_number_or_name="042") - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - with pytest.raises(DeviceAlreadyConnectedError): + config = RealSenseCameraConfig(serial_number_or_name="042", warmup_s=0) + with RealSenseCamera(config) as camera, pytest.raises(DeviceAlreadyConnectedError): camera.connect(warmup=False) @@ -96,12 +92,10 @@ def test_invalid_width_connect(): def test_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - img = camera.read() - assert isinstance(img, np.ndarray) + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) # TODO(Steven): Fix this test for the latest version of pyrealsense2. 
@@ -142,32 +136,21 @@ def test_disconnect_before_connect(): def test_async_read(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) - try: + with RealSenseCamera(config) as camera: img = camera.async_read() assert camera.thread is not None assert camera.thread.is_alive() assert isinstance(img, np.ndarray) - finally: - if camera.is_connected: - camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends def test_async_read_timeout(): - config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) - camera = RealSenseCamera(config) - camera.connect(warmup=False) - - try: - with pytest.raises(TimeoutError): - camera.async_read(timeout_ms=0) - finally: - if camera.is_connected: - camera.disconnect() + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera, pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) # consumes any available frame by then + camera.async_read(timeout_ms=0) # request immediately another one def test_async_read_before_connect(): @@ -178,6 +161,47 @@ def test_async_read_before_connect(): _ = camera.async_read() +def test_read_latest(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + latest = camera.read_latest() + + assert isinstance(latest, np.ndarray) + assert latest.shape == img.shape + + +def test_read_latest_high_frequency(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, warmup_s=0) + with RealSenseCamera(config) as camera: + # prime with one read to ensure frames are available + ref = camera.read() + + for _ in 
range(20): + latest = camera.read_latest() + assert isinstance(latest, np.ndarray) + assert latest.shape == ref.shape + + +def test_read_latest_before_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.read_latest() + + +def test_read_latest_too_old(): + config = RealSenseCameraConfig(serial_number_or_name="042") + + with RealSenseCamera(config) as camera: + # prime to ensure frames are available + _ = camera.read() + + with pytest.raises(TimeoutError): + _ = camera.read_latest(max_age_ms=0) # immediately too old + + @pytest.mark.parametrize( "rotation", [ @@ -189,18 +213,16 @@ def test_async_read_before_connect(): ids=["no_rot", "rot90", "rot180", "rot270"], ) def test_rotation(rotation): - config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation) - camera = RealSenseCamera(config) - camera.connect(warmup=False) + config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation, warmup_s=0) + with RealSenseCamera(config) as camera: + img = camera.read() + assert isinstance(img, np.ndarray) - img = camera.read() - assert isinstance(img, np.ndarray) - - if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): - assert camera.width == 480 - assert camera.height == 640 - assert img.shape[:2] == (640, 480) - else: - assert camera.width == 640 - assert camera.height == 480 - assert img.shape[:2] == (480, 640) + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == 480 + assert camera.height == 640 + assert img.shape[:2] == (640, 480) + else: + assert camera.width == 640 + assert camera.height == 480 + assert img.shape[:2] == (480, 640) From 04cbf669cf0565950f8ba66e8a03a66bd8f20d7a Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 30 Jan 2026 12:23:22 +0100 Subject: [PATCH 071/111] fix(sac): make temperature a property to fix checkpoint resume bug (#2877) * fix(sac): 
make temperature a property to fix checkpoint resume bug Temperature was stored as a plain float and not restored after loading a checkpoint, causing incorrect loss computations until update_temperature() was called. Changed to a property that always computes from log_alpha, ensuring correct behavior after checkpoint loading. * simplify docstrings --- src/lerobot/policies/sac/modeling_sac.py | 11 ++++++----- src/lerobot/rl/learner.py | 3 --- tests/policies/test_sac_policy.py | 3 ++- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/lerobot/policies/sac/modeling_sac.py b/src/lerobot/policies/sac/modeling_sac.py index c7c6798ed..d5dd71a48 100644 --- a/src/lerobot/policies/sac/modeling_sac.py +++ b/src/lerobot/policies/sac/modeling_sac.py @@ -239,8 +239,10 @@ class SACPolicy( + target_param.data * (1.0 - self.config.critic_target_update_weight) ) - def update_temperature(self): - self.temperature = self.log_alpha.exp().item() + @property + def temperature(self) -> float: + """Return the current temperature value, always in sync with log_alpha.""" + return self.log_alpha.exp().item() def compute_loss_critic( self, @@ -457,11 +459,10 @@ class SACPolicy( dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) self.target_entropy = -np.prod(dim) / 2 - def _init_temperature(self): - """Set up temperature parameter and initial log_alpha.""" + def _init_temperature(self) -> None: + """Set up temperature parameter (log_alpha).""" temp_init = self.config.temperature_init self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) - self.temperature = self.log_alpha.exp().item() class SACObservationEncoder(nn.Module): diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index abc5c9504..ee09ac9ac 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -545,9 +545,6 @@ def add_actor_information_and_train( training_infos["temperature_grad_norm"] = temp_grad_norm training_infos["temperature"] = 
policy.temperature - # Update temperature - policy.update_temperature() - # Push policy to actors if needed if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py index 8576883bd..6fad2979e 100644 --- a/tests/policies/test_sac_policy.py +++ b/tests/policies/test_sac_policy.py @@ -441,12 +441,13 @@ def test_sac_policy_with_predefined_entropy(): def test_sac_policy_update_temperature(): + """Test that temperature property is always in sync with log_alpha.""" config = create_default_config(continuous_action_dim=10, state_dim=10) policy = SACPolicy(config=config) assert policy.temperature == pytest.approx(1.0) policy.log_alpha.data = torch.tensor([math.log(0.1)]) - policy.update_temperature() + # Temperature property automatically reflects log_alpha changes assert policy.temperature == pytest.approx(0.1) From ec04b7ce3aca23491e42232c1ae723bb4b981993 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Fri, 30 Jan 2026 13:19:42 +0100 Subject: [PATCH 072/111] Feat(dataset_tools.py) Add modify tasks tool (#2875) * feat(datasets): add modify_tasks function for in-place task editing Add a new utility function to modify tasks in LeRobotDataset in-place. This allows users to: - Set a single task for all episodes - Set specific tasks for individual episodes - Combine a default task with per-episode overrides * feat(edit-dataset): add CLI support for modify_tasks operation Integrate the modify_tasks function into lerobot_edit_dataset CLI. Users can now modify dataset tasks via command line: Supports setting a default task, per-episode tasks, or both combined. 
* test(datasets): add tests for modify_tasks function Add comprehensive test coverage for the modify_tasks utility: - Single task for all episodes - Episode-specific task assignment - Default task with per-episode overrides - Error handling for missing/invalid arguments - Verification of task_index correctness - In-place modification behavior - Metadata preservation * respond to copilot review --- src/lerobot/datasets/dataset_tools.py | 126 +++++++++++++++ src/lerobot/scripts/lerobot_edit_dataset.py | 82 +++++++++- tests/datasets/test_dataset_tools.py | 169 ++++++++++++++++++++ 3 files changed, 374 insertions(+), 3 deletions(-) diff --git a/src/lerobot/datasets/dataset_tools.py b/src/lerobot/datasets/dataset_tools.py index e2928e2a6..123d455c6 100644 --- a/src/lerobot/datasets/dataset_tools.py +++ b/src/lerobot/datasets/dataset_tools.py @@ -1396,6 +1396,132 @@ BYTES_PER_KIB = 1024 BYTES_PER_MIB = BYTES_PER_KIB * BYTES_PER_KIB +def modify_tasks( + dataset: LeRobotDataset, + new_task: str | None = None, + episode_tasks: dict[int, str] | None = None, +) -> LeRobotDataset: + """Modify tasks in a LeRobotDataset. + + This function allows you to either: + 1. Set a single task for the entire dataset (using `new_task`) + 2. Set specific tasks for specific episodes (using `episode_tasks`) + + You can combine both: `new_task` sets the default, and `episode_tasks` overrides + specific episodes. + + The dataset is modified in-place, updating only the task-related files: + - meta/tasks.parquet + - data/**/*.parquet (task_index column) + - meta/episodes/**/*.parquet (tasks column) + - meta/info.json (total_tasks) + + Args: + dataset: The source LeRobotDataset to modify. + new_task: A single task string to apply to all episodes. If None and episode_tasks + is also None, raises an error. + episode_tasks: Optional dict mapping episode indices to their task strings. + Overrides `new_task` for specific episodes. 
+ + + Examples: + Set a single task for all episodes: + dataset = modify_tasks(dataset, new_task="Pick up the cube") + + Set different tasks for specific episodes: + dataset = modify_tasks( + dataset, + episode_tasks={0: "Task A", 1: "Task B", 2: "Task A"} + ) + + Set a default task with overrides: + dataset = modify_tasks( + dataset, + new_task="Default task", + episode_tasks={5: "Special task for episode 5"} + ) + """ + if new_task is None and episode_tasks is None: + raise ValueError("Must specify at least one of new_task or episode_tasks") + + if episode_tasks is not None: + valid_indices = set(range(dataset.meta.total_episodes)) + invalid = set(episode_tasks.keys()) - valid_indices + if invalid: + raise ValueError(f"Invalid episode indices: {invalid}") + + # Ensure episodes metadata is loaded + if dataset.meta.episodes is None: + dataset.meta.episodes = load_episodes(dataset.root) + + # Build the mapping from episode index to task string + episode_to_task: dict[int, str] = {} + for ep_idx in range(dataset.meta.total_episodes): + if episode_tasks and ep_idx in episode_tasks: + episode_to_task[ep_idx] = episode_tasks[ep_idx] + elif new_task is not None: + episode_to_task[ep_idx] = new_task + else: + # Keep original task if not overridden and no default provided + original_tasks = dataset.meta.episodes[ep_idx]["tasks"] + if not original_tasks: + raise ValueError(f"Episode {ep_idx} has no tasks and no default task was provided") + episode_to_task[ep_idx] = original_tasks[0] + + # Collect all unique tasks and create new task mapping + unique_tasks = sorted(set(episode_to_task.values())) + new_task_df = pd.DataFrame({"task_index": list(range(len(unique_tasks)))}, index=unique_tasks) + task_to_index = {task: idx for idx, task in enumerate(unique_tasks)} + + logging.info(f"Modifying tasks in {dataset.repo_id}") + logging.info(f"New tasks: {unique_tasks}") + + root = dataset.root + + # Update data files - modify task_index column + logging.info("Updating data 
files...") + data_dir = root / DATA_DIR + + for parquet_path in tqdm(sorted(data_dir.rglob("*.parquet")), desc="Updating data"): + df = pd.read_parquet(parquet_path) + + # Build a mapping from episode_index to new task_index for rows in this file + episode_indices_in_file = df["episode_index"].unique() + ep_to_new_task_idx = { + ep_idx: task_to_index[episode_to_task[ep_idx]] for ep_idx in episode_indices_in_file + } + + # Update task_index column + df["task_index"] = df["episode_index"].map(ep_to_new_task_idx) + df.to_parquet(parquet_path, index=False) + + # Update episodes metadata - modify tasks column + logging.info("Updating episodes metadata...") + episodes_dir = root / "meta" / "episodes" + + for parquet_path in tqdm(sorted(episodes_dir.rglob("*.parquet")), desc="Updating episodes"): + df = pd.read_parquet(parquet_path) + + # Update tasks column + df["tasks"] = df["episode_index"].apply(lambda ep_idx: [episode_to_task[ep_idx]]) + df.to_parquet(parquet_path, index=False) + + # Write new tasks.parquet + write_tasks(new_task_df, root) + + # Update info.json + dataset.meta.info["total_tasks"] = len(unique_tasks) + write_info(dataset.meta.info, root) + + # Reload metadata to reflect changes + dataset.meta.tasks = new_task_df + dataset.meta.episodes = load_episodes(root) + + logging.info(f"Tasks: {unique_tasks}") + + return dataset + + def convert_image_to_video_dataset( dataset: LeRobotDataset, output_dir: Path, diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 4ba6ce44f..2ca9c520d 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -18,7 +18,7 @@ Edit LeRobot datasets using various transformation tools. This script allows you to delete episodes, split datasets, merge datasets, -remove features, and convert image datasets to video format. +remove features, modify tasks, and convert image datasets to video format. 
When new_repo_id is specified, creates a new dataset. Usage Examples: @@ -66,6 +66,25 @@ Remove camera feature: --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" +Modify tasks - set a single task for all episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Pick up the cube and place it" + +Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}' + +Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place): + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht \ + --operation.type modify_tasks \ + --operation.new_task "Default task" \ + --operation.episode_tasks '{"5": "Special task for episode 5"}' + Convert image dataset to video format and save locally: python -m lerobot.scripts.lerobot_edit_dataset \ --repo_id lerobot/pusht_image \ @@ -100,6 +119,7 @@ from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, delete_episodes, merge_datasets, + modify_tasks, remove_feature, split_dataset, ) @@ -132,6 +152,13 @@ class RemoveFeatureConfig: feature_names: list[str] | None = None +@dataclass +class ModifyTasksConfig: + type: str = "modify_tasks" + new_task: str | None = None + episode_tasks: dict[str, str] | None = None + + @dataclass class ConvertImageToVideoConfig: type: str = "convert_image_to_video" @@ -151,7 +178,12 @@ class ConvertImageToVideoConfig: class EditDatasetConfig: repo_id: str operation: ( - DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig + DeleteEpisodesConfig + | SplitConfig + | MergeConfig + | RemoveFeatureConfig + | ModifyTasksConfig + | 
ConvertImageToVideoConfig ) root: str | None = None new_repo_id: str | None = None @@ -296,6 +328,48 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None: LeRobotDataset(output_repo_id, root=output_dir).push_to_hub() +def handle_modify_tasks(cfg: EditDatasetConfig) -> None: + if not isinstance(cfg.operation, ModifyTasksConfig): + raise ValueError("Operation config must be ModifyTasksConfig") + + new_task = cfg.operation.new_task + episode_tasks_raw = cfg.operation.episode_tasks + + if new_task is None and episode_tasks_raw is None: + raise ValueError("Must specify at least one of new_task or episode_tasks for modify_tasks operation") + + # Warn about in-place modification behavior + if cfg.new_repo_id is not None: + logging.warning("modify_tasks modifies datasets in-place. The --new_repo_id parameter is ignored.") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + logging.warning(f"Modifying dataset in-place at {dataset.root}. Original data will be overwritten.") + + # Convert episode_tasks keys from string to int if needed (CLI passes strings) + episode_tasks: dict[int, str] | None = None + if episode_tasks_raw is not None: + episode_tasks = {int(k): v for k, v in episode_tasks_raw.items()} + + logging.info(f"Modifying tasks in {cfg.repo_id}") + if new_task: + logging.info(f" Default task: '{new_task}'") + if episode_tasks: + logging.info(f" Episode-specific tasks: {episode_tasks}") + + modified_dataset = modify_tasks( + dataset, + new_task=new_task, + episode_tasks=episode_tasks, + ) + + logging.info(f"Dataset modified at {dataset.root}") + logging.info(f"Tasks: {list(modified_dataset.meta.tasks.index)}") + + if cfg.push_to_hub: + logging.info(f"Pushing to hub as {cfg.repo_id}") + modified_dataset.push_to_hub() + + def handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: # Note: Parser may create any config type with the right fields, so we access fields directly # instead of checking isinstance() @@ -371,12 +445,14 @@ def 
edit_dataset(cfg: EditDatasetConfig) -> None: handle_merge(cfg) elif operation_type == "remove_feature": handle_remove_feature(cfg) + elif operation_type == "modify_tasks": + handle_modify_tasks(cfg) elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: raise ValueError( f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video" + f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" ) diff --git a/tests/datasets/test_dataset_tools.py b/tests/datasets/test_dataset_tools.py index 35a369de9..1de199630 100644 --- a/tests/datasets/test_dataset_tools.py +++ b/tests/datasets/test_dataset_tools.py @@ -26,6 +26,7 @@ from lerobot.datasets.dataset_tools import ( delete_episodes, merge_datasets, modify_features, + modify_tasks, remove_feature, split_dataset, ) @@ -1050,6 +1051,174 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path): assert "reward" in modified_dataset.meta.features +def test_modify_tasks_single_task_for_all(sample_dataset): + """Test setting a single task for all episodes.""" + new_task = "Pick up the cube and place it" + + modified_dataset = modify_tasks(sample_dataset, new_task=new_task) + + # Verify all episodes have the new task + assert len(modified_dataset.meta.tasks) == 1 + assert new_task in modified_dataset.meta.tasks.index + + # Verify task_index is 0 for all frames (only one task) + for i in range(len(modified_dataset)): + item = modified_dataset[i] + assert item["task_index"].item() == 0 + assert item["task"] == new_task + + +def test_modify_tasks_episode_specific(sample_dataset): + """Test setting different tasks for specific episodes.""" + episode_tasks = { + 0: "Task A", + 1: "Task B", + 2: "Task A", + 3: "Task C", + 4: "Task B", + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify correct number of unique 
tasks + unique_tasks = set(episode_tasks.values()) + assert len(modified_dataset.meta.tasks) == len(unique_tasks) + + # Verify each episode has the correct task + for ep_idx, expected_task in episode_tasks.items(): + ep_data = modified_dataset.meta.episodes[ep_idx] + assert ep_data["tasks"][0] == expected_task + + +def test_modify_tasks_default_with_overrides(sample_dataset): + """Test setting a default task with specific overrides.""" + default_task = "Default task" + override_task = "Special task" + episode_tasks = {2: override_task, 4: override_task} + + modified_dataset = modify_tasks( + sample_dataset, + new_task=default_task, + episode_tasks=episode_tasks, + ) + + # Verify correct number of unique tasks + assert len(modified_dataset.meta.tasks) == 2 + assert default_task in modified_dataset.meta.tasks.index + assert override_task in modified_dataset.meta.tasks.index + + # Verify episodes have correct tasks + for ep_idx in range(5): + ep_data = modified_dataset.meta.episodes[ep_idx] + if ep_idx in episode_tasks: + assert ep_data["tasks"][0] == override_task + else: + assert ep_data["tasks"][0] == default_task + + +def test_modify_tasks_no_task_specified(sample_dataset): + """Test error when no task is specified.""" + with pytest.raises(ValueError, match="Must specify at least one of new_task or episode_tasks"): + modify_tasks(sample_dataset) + + +def test_modify_tasks_invalid_episode_indices(sample_dataset): + """Test error with invalid episode indices.""" + with pytest.raises(ValueError, match="Invalid episode indices"): + modify_tasks(sample_dataset, episode_tasks={10: "Task", 20: "Task"}) + + +def test_modify_tasks_updates_info_json(sample_dataset): + """Test that total_tasks is updated in info.json.""" + episode_tasks = {0: "Task A", 1: "Task B", 2: "Task C", 3: "Task A", 4: "Task B"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify total_tasks is updated + assert modified_dataset.meta.total_tasks == 3 + + +def 
test_modify_tasks_preserves_other_metadata(sample_dataset): + """Test that modifying tasks preserves other metadata.""" + original_frames = sample_dataset.meta.total_frames + original_episodes = sample_dataset.meta.total_episodes + original_fps = sample_dataset.meta.fps + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify other metadata is preserved + assert modified_dataset.meta.total_frames == original_frames + assert modified_dataset.meta.total_episodes == original_episodes + assert modified_dataset.meta.fps == original_fps + + +def test_modify_tasks_task_index_correct(sample_dataset): + """Test that task_index values are correct in data files.""" + # Create tasks that will have predictable indices (sorted alphabetically) + episode_tasks = { + 0: "Alpha task", # Will be index 0 + 1: "Beta task", # Will be index 1 + 2: "Alpha task", # Will be index 0 + 3: "Gamma task", # Will be index 2 + 4: "Beta task", # Will be index 1 + } + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify task indices are correct + task_to_expected_idx = { + "Alpha task": 0, + "Beta task": 1, + "Gamma task": 2, + } + + for i in range(len(modified_dataset)): + item = modified_dataset[i] + ep_idx = item["episode_index"].item() + expected_task = episode_tasks[ep_idx] + expected_idx = task_to_expected_idx[expected_task] + assert item["task_index"].item() == expected_idx + assert item["task"] == expected_task + + +def test_modify_tasks_in_place(sample_dataset): + """Test that modify_tasks modifies the dataset in-place.""" + original_root = sample_dataset.root + + modified_dataset = modify_tasks(sample_dataset, new_task="New task") + + # Verify same instance is returned and root is unchanged + assert modified_dataset is sample_dataset + assert modified_dataset.root == original_root + + +def test_modify_tasks_keeps_original_when_not_overridden(sample_dataset): + """Test that original tasks are kept when using episode_tasks 
without new_task.""" + from lerobot.datasets.utils import load_episodes + + # Ensure episodes metadata is loaded + if sample_dataset.meta.episodes is None: + sample_dataset.meta.episodes = load_episodes(sample_dataset.meta.root) + + # Get original tasks for episodes not being overridden + original_task_ep0 = sample_dataset.meta.episodes[0]["tasks"][0] + original_task_ep1 = sample_dataset.meta.episodes[1]["tasks"][0] + + # Only override episodes 2, 3, 4 + episode_tasks = {2: "New Task A", 3: "New Task B", 4: "New Task A"} + + modified_dataset = modify_tasks(sample_dataset, episode_tasks=episode_tasks) + + # Verify original tasks are kept for episodes 0 and 1 + assert modified_dataset.meta.episodes[0]["tasks"][0] == original_task_ep0 + assert modified_dataset.meta.episodes[1]["tasks"][0] == original_task_ep1 + + # Verify new tasks for overridden episodes + assert modified_dataset.meta.episodes[2]["tasks"][0] == "New Task A" + assert modified_dataset.meta.episodes[3]["tasks"][0] == "New Task B" + assert modified_dataset.meta.episodes[4]["tasks"][0] == "New Task A" + + def test_convert_image_to_video_dataset(tmp_path): """Test converting lerobot/pusht_image dataset to video format.""" from lerobot.datasets.lerobot_dataset import LeRobotDataset From 55c0471db9e440e99e801e2e67d645ecd7fdb9d5 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 30 Jan 2026 16:57:56 +0100 Subject: [PATCH 073/111] docs(cameras): revising and improving docs on cameras (#2878) * docs(cameras): revising and improving docs on cameras * resolving copilot comments --- docs/source/_toctree.yml | 6 +- docs/source/cameras.mdx | 176 +++++++++++++++++++++------------------ 2 files changed, 99 insertions(+), 83 deletions(-) diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index eb97117af..98417f134 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -7,8 +7,6 @@ - sections: - local: il_robots title: Imitation Learning for Robots - - local: cameras - title: 
Cameras - local: bring_your_own_policies title: Bring Your Own Policies - local: integrate_hardware @@ -108,6 +106,10 @@ - local: phone_teleop title: Phone title: "Teleoperators" +- sections: + - local: cameras + title: Cameras + title: "Sensors" - sections: - local: torch_accelerators title: PyTorch accelerators diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 5c35be0ba..8af0f5ae5 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -1,12 +1,22 @@ # Cameras -LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). +LeRobot offers multiple options for video capture: -### Finding your camera +| Class | Supported Cameras | +| ----------------- | ----------------------------------- | +| `OpenCVCamera` | Phone, built-in laptop, USB webcams | +| `ZMQCamera` | Network-connected cameras | +| `RealSenseCamera` | Intel RealSense (with depth) | +| `Reachy2Camera` | Reachy 2 robot cameras | -To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system. +> [!TIP] +> For `OpenCVCamera` compatibility details, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). -To find the camera indices of the cameras plugged into your system, run the following script: +### Find your camera + +Every camera requires a unique identifier to be instantiated, allowing you to distinguish between multiple connected devices. + +`OpenCVCamera` and `RealSenseCamera` support auto-discovery. 
Run the command below to list available devices and their identifiers. Note that these identifiers may change after rebooting your computer or re-plugging the camera, depending on your operating system. ```bash lerobot-find-cameras opencv # or realsense for Intel Realsense cameras @@ -14,7 +24,7 @@ lerobot-find-cameras opencv # or realsense for Intel Realsense cameras The output will look something like this if you have two cameras connected: -``` +```bash --- Detected Cameras --- Camera #0: Name: OpenCV Camera @ 0 @@ -33,13 +43,37 @@ Camera #0: > [!WARNING] > When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable. -## Use Cameras +`ZMQCamera` and `Reachy2Camera` do not support auto-discovery. They must be configured manually by providing their network address and port or robot SDK settings. -Below are two examples, demonstrating how to work with the API. +## Use cameras -- **Asynchronous frame capture** using an OpenCV-based camera +### Frame access modes + +All camera classes implement three access modes for capturing frames: + +| Method | Behavior | Blocks? | Best For | +| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------- | ---------------------------------------- | +| `read()` | Waits for the camera hardware to return a frame. May block for a long time depending on the camera and SDK. | Yes | Simple scripts, sequential capture | +| `async_read(timeout_ms)` | Returns the latest unconsumed frame from background thread. Blocks only if buffer is empty, up to `timeout_ms`. Raises `TimeoutError` if no frame arrives. 
| With a timeout | Control loops synchronized to camera FPS | +| `read_latest(max_age_ms)` | Peeks at the most recent frame in buffer (may be stale). Raises `TimeoutError` if frame is older than `max_age_ms`. | No | UI visualization, logging, monitoring | + +### Usage examples + +The following examples show how to use the camera API to configure and capture frames from different camera types. + +- **Blocking and non-blocking frame capture** using an OpenCV-based camera - **Color and depth capture** using an Intel RealSense camera +> [!WARNING] +> Failing to cleanly disconnect cameras can cause resource leaks. Use the context manager protocol to ensure automatic cleanup: +> +> ```python +> with OpenCVCamera(config) as camera: +> ... +> ``` +> +> You can also call `connect()` and `disconnect()` manually, but always use a `finally` block for the latter. + @@ -60,16 +94,30 @@ config = OpenCVCameraConfig( ) # Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default). -camera = OpenCVCamera(config) -camera.connect() +with OpenCVCamera(config) as camera: + + # Read a frame synchronously — blocks until hardware delivers a new frame + frame = camera.read() + print(f"read() call returned frame with shape:", frame.shape) + + # Read a frame asynchronously with a timeout — returns the latest unconsumed frame or waits up to timeout_ms for a new one + try: + for i in range(10): + frame = camera.async_read(timeout_ms=200) + print(f"async_read call returned frame {i} with shape:", frame.shape) + except TimeoutError as e: + print(f"No frame received within timeout: {e}") + + # Instantly return a frame - returns the most recent frame captured by the camera + try: + initial_frame = camera.read_latest(max_age_ms=1000) + for i in range(10): + frame = camera.read_latest(max_age_ms=1000) + print(f"read_latest call returned frame {i} with shape:", frame.shape) + print(f"Was a new frame received by the camera? 
{not (initial_frame == frame).all()}") + except TimeoutError as e: + print(f"Frame too old: {e}") -# Read frames asynchronously in a loop via `async_read(timeout_ms)` -try: - for i in range(10): - frame = camera.async_read(timeout_ms=200) - print(f"Async frame {i} shape:", frame.shape) -finally: - camera.disconnect() ``` @@ -111,10 +159,10 @@ finally: -## Use your phone +## Use your phone's camera - + To use your iPhone as a camera on macOS, enable the Continuity Camera feature: @@ -124,83 +172,49 @@ To use your iPhone as a camera on macOS, enable the Continuity Camera feature: For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac). -Your iPhone should be detected automatically when running the camera setup script in the next section. - - + -If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera +If you want to use your phone as a camera using OBS, follow these steps to set up a virtual camera. -1. _Install `v4l2loopback-dkms` and `v4l-utils`_. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using: +1. _(Linux only) Install `v4l2loopback-dkms` and `v4l-utils`_. These packages create virtual camera devices and verify their settings. Install with: - -```python +```bash sudo apt install v4l2loopback-dkms v4l-utils ``` - -2. _Install [DroidCam](https://droidcam.app) on your phone_. This app is available for both iOS and Android. -3. _Install [OBS Studio](https://obsproject.com)_. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org): +2. _Install the [DroidCam app](https://droidcam.app) on your phone_. This app is available for both iOS and Android. +3. _Download and install [OBS Studio](https://obsproject.com)_. +4. _Download and install the [DroidCam OBS plugin](https://droidcam.app/obs)_. +5. _Start OBS Studio_. 
- -```python -flatpak install flathub com.obsproject.Studio -``` - - -4. _Install the DroidCam OBS plugin_. This plugin integrates DroidCam with OBS Studio. Install it with: - - -```python -flatpak install flathub com.obsproject.Studio.Plugin.DroidCam -``` - - -5. _Start OBS Studio_. Launch with: - - -```python -flatpak run com.obsproject.Studio -``` - - -6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`. -7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in. +6. _Add your phone as a source_. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480` to avoid the watermarks. +7. _Adjust resolution settings_. In OBS Studio, go to `File > Settings > Video` or `OBS > Preferences... > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it. 8. _Start virtual camera_. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide). -9. _Verify the virtual camera setup_. Use `v4l2-ctl` to list the devices: +9. _Verify the virtual camera setup and resolution_. + - **Linux**: Use `v4l2-ctl` to list devices and check resolution: + ```bash + v4l2-ctl --list-devices # find VirtualCam and note its /dev/videoX path + v4l2-ctl -d /dev/videoX --get-fmt-video # replace with your VirtualCam path + ``` + You should see `VirtualCam` listed and resolution `640x480`. + - **macOS**: Open Photo Booth or FaceTime and select "OBS Virtual Camera" as the input. + - **Windows**: The native Camera app doesn't support virtual cameras. Use a video conferencing app (Zoom, Teams) or run `lerobot-find-cameras opencv` directly to verify. - -```python -v4l2-ctl --list-devices -``` - +
+Troubleshooting -You should see an entry like: +> The virtual camera resolution is incorrect. -``` -VirtualCam (platform:v4l2loopback-000): -/dev/video1 -``` +Delete the virtual camera source and recreate it. The resolution cannot be changed after creation. -10. _Check the camera resolution_. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`. +> Error reading frame in background thread for OpenCVCamera(X): OpenCVCamera(X) frame width=640 or height=480 do not match configured width=1920 or height=1080. - -```python -v4l2-ctl -d /dev/video1 --get-fmt-video -``` - +This error is caused by OBS Virtual Camera advertising a `1920x1080` resolution despite rescaling. The only fix for now is to comment out the width and height check in `_postprocess_image()`. -You should see an entry like: - -``` ->>> Format Video Capture: ->>> Width/Height : 640/480 ->>> Pixel Format : 'YUYV' (YUYV 4:2:2) -``` - -Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed. - -If everything is set up correctly, you can proceed with the rest of the tutorial. +
+ +If everything is set up correctly, your phone will appear as a standard OpenCV camera and can be used with `OpenCVCamera`. From 5c6182176f31996fb5d0c51f88a1bc59457ba7a6 Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Fri, 30 Jan 2026 16:58:13 +0100 Subject: [PATCH 074/111] fix(find zmq): adding a clearer not implemented warning for the ZMQ find_cameras method (#2879) Co-authored-by: Martino Russi <77496684+nepyope@users.noreply.github.com> --- src/lerobot/cameras/zmq/camera_zmq.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index a231a582a..f29e16a28 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -166,8 +166,10 @@ class ZMQCamera(Camera): @staticmethod def find_cameras() -> list[dict[str, Any]]: - """ZMQ cameras require manual configuration (server address/port).""" - return [] + """ + Detection not implemented for ZMQ cameras. These cameras require manual configuration (server address/port). 
+ """ + raise NotImplementedError("Camera detection is not implemented for ZMQ cameras.") def _read_from_hardware(self) -> NDArray[Any]: """ From b18cef2e260a80db6cbe2327140950964c797b46 Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 30 Jan 2026 10:29:37 -0800 Subject: [PATCH 075/111] feat(dataset): add subtask support (#2860) * add subtask * remove folder * add docs * update doc * add testing * update test * update constant naming + doc * more docs --- docs/source/_toctree.yml | 2 + docs/source/dataset_subtask.mdx | 278 +++++++++++ src/lerobot/datasets/lerobot_dataset.py | 9 + src/lerobot/datasets/utils.py | 9 + src/lerobot/processor/converters.py | 3 +- src/lerobot/processor/tokenizer_processor.py | 46 ++ src/lerobot/utils/constants.py | 3 + tests/datasets/test_subtask_dataset.py | 190 ++++++++ tests/processor/test_tokenizer_processor.py | 465 ++++++++++++++++++- 9 files changed, 1003 insertions(+), 2 deletions(-) create mode 100644 docs/source/dataset_subtask.mdx create mode 100644 tests/datasets/test_subtask_dataset.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 98417f134..d61aac9c1 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -27,6 +27,8 @@ title: Porting Large Datasets - local: using_dataset_tools title: Using the Dataset Tools + - local: dataset_subtask + title: Using Subtasks in the Dataset title: "Datasets" - sections: - local: act diff --git a/docs/source/dataset_subtask.mdx b/docs/source/dataset_subtask.mdx new file mode 100644 index 000000000..beb5d80bd --- /dev/null +++ b/docs/source/dataset_subtask.mdx @@ -0,0 +1,278 @@ +# Using Subtasks in LeRobot Datasets + +Subtask support in robotics datasets has proven effective in improving robot reasoning and understanding. 
Subtasks are particularly useful for: + +- **Hierarchical policies**: Building policies that include subtask predictions to visualize robot reasoning in real time +- **Reward modeling**: Helping reward models understand task progression (e.g., SARM-style stage-aware reward models) +- **Task decomposition**: Breaking down complex manipulation tasks into atomic, interpretable steps + +LeRobotDataset now supports subtasks as part of its dataset structure, alongside tasks. + +## What are Subtasks? + +While a **task** describes the overall goal (e.g., "Pick up the apple and place it in the basket"), **subtasks** break down the execution into finer-grained steps: + +1. "Approach the apple" +2. "Grasp the apple" +3. "Lift the apple" +4. "Move to basket" +5. "Release the apple" + +Each frame in the dataset can be annotated with its corresponding subtask, enabling models to learn and predict these intermediate stages. + +An overview of subtask annotation showing how frames are labeled with intermediate subtask stages + +

+ Figure: Overview of subtask annotation. +

+ +**Reference:** _Subtask-learning based for robot self-assembly in flexible collaborative assembly in manufacturing_, Original Article, Published: 19 April 2022. + +## Dataset Structure + +Subtask information is stored in the dataset metadata: + +``` +my-dataset/ +├── data/ +│ └── ... +├── meta/ +│ ├── info.json +│ ├── stats.json +│ ├── tasks.parquet +│ ├── subtasks.parquet # Subtask index → subtask string mapping +│ └── episodes/ +│ └── ... +└── videos/ + └── ... +``` + +### Subtasks Parquet File + +The `meta/subtasks.parquet` file maps subtask indices to their natural language descriptions: + +| subtask_index | subtask (index column) | +| ------------- | ---------------------- | +| 0 | "Approach the apple" | +| 1 | "Grasp the apple" | +| 2 | "Lift the apple" | +| ... | ... | + +### Frame-Level Annotations + +Each frame in the dataset can include a `subtask_index` field that references the subtasks parquet file: + +```python +# Example frame data in the parquet file +{ + "index": 42, + "timestamp": 1.4, + "episode_index": 0, + "task_index": 0, + "subtask_index": 2, # References "Lift the apple" + "observation.state": [...], + "action": [...], +} +``` + +## Annotating Datasets with Subtasks + +We provide a HuggingFace Space for easily annotating any LeRobotDataset with subtasks: + +**[https://huggingface.co/spaces/lerobot/annotate](https://huggingface.co/spaces/lerobot/annotate)** + +After completing your annotation: + +1. Click "Push to Hub" to upload your annotated dataset +2. 
You can also run the annotation space locally by following the instructions at [github.com/huggingface/lerobot-annotate](https://github.com/huggingface/lerobot-annotate) + +## Loading Datasets with Subtasks + +When you load a dataset with subtask annotations, the subtask information is automatically available: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Load a dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Access a sample +sample = dataset[100] + +# The sample includes both task and subtask information +print(sample["task"]) # "Collect the fruit" +print(sample["subtask"]) # "Grasp the apple" +print(sample["task_index"]) # tensor(0) +print(sample["subtask_index"]) # tensor(2) +``` + +### Checking for Subtask Support + +You can check if a dataset has subtask annotations: + +```python +# Check if subtasks are available +has_subtasks = ( + "subtask_index" in dataset.features + and dataset.meta.subtasks is not None +) + +if has_subtasks: + print(f"Dataset has {len(dataset.meta.subtasks)} unique subtasks") + print("Subtasks:", list(dataset.meta.subtasks.index)) +``` + +## Using Subtasks for Training + +### With the Tokenizer Processor + +The `TokenizerProcessorStep` automatically handles subtask tokenization for Vision-Language Action (VLA) models: + +```python +from lerobot.processor.tokenizer_processor import TokenizerProcessorStep +from lerobot.processor.pipeline import ProcessorPipeline + +# Create a tokenizer processor +tokenizer_processor = TokenizerProcessorStep( + tokenizer_name_or_path="google/paligemma-3b-pt-224", + padding="max_length", + max_length=64, +) + +# The processor will automatically tokenize subtasks if present in the batch +# and add them to the observation under: +# - "observation.subtask.tokens" +# - "observation.subtask.attention_mask" +``` + +When subtasks are available in the batch, the tokenizer processor adds: + +- `observation.subtask.tokens`: Tokenized subtask 
text +- `observation.subtask.attention_mask`: Attention mask for the subtask tokens + +### DataLoader with Subtasks + +```python +import torch +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=16, + shuffle=True, +) + +for batch in dataloader: + # Access subtask information in the batch + subtasks = batch["subtask"] # List of subtask strings + subtask_indices = batch["subtask_index"] # Tensor of subtask indices + + # Use for training hierarchical policies or reward models + print(f"Batch subtasks: {set(subtasks)}") +``` + +## Example Datasets with Subtask Annotations + +Try loading a dataset with subtask annotations: + +```python +from lerobot.datasets.lerobot_dataset import LeRobotDataset + +# Example dataset with subtask annotations +dataset = LeRobotDataset("jadechoghari/collect-fruit-annotated") + +# Explore the subtasks +print("Available subtasks:") +for subtask_name in dataset.meta.subtasks.index: + print(f" - {subtask_name}") + +# Get subtask distribution +subtask_counts = {} +for i in range(len(dataset)): + sample = dataset[i] + subtask = sample["subtask"] + subtask_counts[subtask] = subtask_counts.get(subtask, 0) + 1 + +print("\nSubtask distribution:") +for subtask, count in sorted(subtask_counts.items(), key=lambda x: -x[1]): + print(f" {subtask}: {count} frames") +``` + +## Use Cases + +### 1. Hierarchical Policy Training + +Train policies that predict both actions and current subtask: + +```python +class HierarchicalPolicy(nn.Module): + def __init__(self, num_subtasks): + super().__init__() + self.action_head = nn.Linear(hidden_dim, action_dim) + self.subtask_head = nn.Linear(hidden_dim, num_subtasks) + + def forward(self, observations): + features = self.encoder(observations) + actions = self.action_head(features) + subtask_logits = self.subtask_head(features) + return actions, subtask_logits +``` + +### 2. 
Stage-Aware Reward Modeling (SARM) + +Build reward models that understand task progression: + +```python +# SARM predicts: +# - Stage: Which subtask is being executed (discrete) +# - Progress: How far along the subtask (continuous 0-1) + +class SARMRewardModel(nn.Module): + def forward(self, observations): + features = self.encoder(observations) + stage_logits = self.stage_classifier(features) + progress = self.progress_regressor(features) + return stage_logits, progress +``` + +### 3. Progress Visualization + +Monitor robot execution by tracking subtask progression: + +```python +def visualize_execution(model, observations): + for t, obs in enumerate(observations): + action, subtask_logits = model(obs) + predicted_subtask = subtask_names[subtask_logits.argmax()] + print(f"t={t}: Executing '{predicted_subtask}'") +``` + +## API Reference + +### LeRobotDataset Properties + +| Property | Type | Description | +| --------------------------- | ---------------------- | ------------------------------------------ | +| `meta.subtasks` | `pd.DataFrame \| None` | DataFrame mapping subtask names to indices | +| `features["subtask_index"]` | `dict` | Feature spec for subtask_index if present | + +### Sample Keys + +When subtasks are available, each sample includes: + +| Key | Type | Description | +| --------------- | -------------- | ------------------------------------ | +| `subtask_index` | `torch.Tensor` | Integer index of the current subtask | +| `subtask` | `str` | Natural language subtask description | + +## Related Resources + +- [SARM Paper](https://arxiv.org/pdf/2509.25358) - Stage-Aware Reward Modeling for Long Horizon Robot Manipulation +- [LeRobot Annotate Space](https://huggingface.co/spaces/lerobot/annotate) - Interactive annotation tool +- [LeRobotDataset v3.0](./lerobot-dataset-v3) - Dataset format documentation diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 6798e7fd7..36bffa190 100644 --- 
a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -57,6 +57,7 @@ from lerobot.datasets.utils import ( load_info, load_nested_dataset, load_stats, + load_subtasks, load_tasks, update_chunk_file_indices, validate_episode_buffer, @@ -162,6 +163,7 @@ class LeRobotDatasetMetadata: self.info = load_info(self.root) check_version_compatibility(self.repo_id, self._version, CODEBASE_VERSION) self.tasks = load_tasks(self.root) + self.subtasks = load_subtasks(self.root) self.episodes = load_episodes(self.root) self.stats = load_stats(self.root) @@ -518,6 +520,7 @@ class LeRobotDatasetMetadata: _validate_feature_names(features) obj.tasks = None + obj.subtasks = None obj.episodes = None obj.stats = None obj.info = create_empty_dataset_info( @@ -1075,6 +1078,12 @@ class LeRobotDataset(torch.utils.data.Dataset): # Add task as a string task_idx = item["task_index"].item() item["task"] = self.meta.tasks.iloc[task_idx].name + + # add subtask information if available + if "subtask_index" in self.features and self.meta.subtasks is not None: + subtask_idx = item["subtask_index"].item() + item["subtask"] = self.meta.subtasks.iloc[subtask_idx].name + return item def __repr__(self): diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index ed678af6e..321ecedd5 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -60,6 +60,7 @@ VIDEO_DIR = "videos" CHUNK_FILE_PATTERN = "chunk-{chunk_index:03d}/file-{file_index:03d}" DEFAULT_TASKS_PATH = "meta/tasks.parquet" +DEFAULT_SUBTASKS_PATH = "meta/subtasks.parquet" DEFAULT_EPISODES_PATH = EPISODES_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_DATA_PATH = DATA_DIR + "/" + CHUNK_FILE_PATTERN + ".parquet" DEFAULT_VIDEO_PATH = VIDEO_DIR + "/{video_key}/" + CHUNK_FILE_PATTERN + ".mp4" @@ -353,6 +354,14 @@ def load_tasks(local_dir: Path) -> pandas.DataFrame: return tasks +def load_subtasks(local_dir: Path) -> pandas.DataFrame | None: + """Load subtasks from 
subtasks.parquet if it exists.""" + subtasks_path = local_dir / DEFAULT_SUBTASKS_PATH + if subtasks_path.exists(): + return pd.read_parquet(subtasks_path) + return None + + def write_episodes(episodes: Dataset, local_dir: Path) -> None: """Write episode metadata to a parquet file in the LeRobot v3.0 format. This function writes episode-level metadata to a single parquet file. diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index 4f9485fee..18c7b0220 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -168,11 +168,12 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]: """ pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} task_key = {"task": batch["task"]} if "task" in batch else {} + subtask_key = {"subtask": batch["subtask"]} if "subtask" in batch else {} index_key = {"index": batch["index"]} if "index" in batch else {} task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {} episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {} - return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key} + return {**pad_keys, **task_key, **subtask_key, **index_key, **task_index_key, **episode_index_key} def create_transition( diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 5cd1bebb0..df559555a 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -34,6 +34,8 @@ from lerobot.utils.constants import ( ACTION_TOKEN_MASK, ACTION_TOKENS, OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, OBS_LANGUAGE_TOKENS, ) from lerobot.utils.import_utils import _transformers_available @@ -139,6 +141,32 @@ class TokenizerProcessorStep(ObservationProcessorStep): return None + def get_subtask(self, transition: EnvTransition) 
-> list[str] | None: + """ + Extracts the subtask from the transition's complementary data. + + Args: + transition: The environment transition. + + Returns: + A list of subtask strings, or None if the subtask key is not found or the value is None. + """ + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return None + + subtask = complementary_data.get("subtask") + if subtask is None: + return None + + # Standardize to a list of strings for the tokenizer + if isinstance(subtask, str): + return [subtask] + elif isinstance(subtask, list) and all(isinstance(t, str) for t in subtask): + return subtask + + return None + def observation(self, observation: RobotObservation) -> RobotObservation: """ Tokenizes the task description and adds it to the observation dictionary. @@ -176,6 +204,24 @@ class TokenizerProcessorStep(ObservationProcessorStep): new_observation[OBS_LANGUAGE_TOKENS] = tokenized_prompt["input_ids"] new_observation[OBS_LANGUAGE_ATTENTION_MASK] = tokenized_prompt["attention_mask"].to(dtype=torch.bool) + # Tokenize subtask if available + subtask = self.get_subtask(self.transition) + if subtask is not None: + tokenized_subtask = self._tokenize_text(subtask) + + # Move new tokenized tensors to the detected device + if target_device is not None: + tokenized_subtask = { + k: v.to(target_device) if isinstance(v, torch.Tensor) else v + for k, v in tokenized_subtask.items() + } + + # Add tokenized subtask to the observation + new_observation[OBS_LANGUAGE_SUBTASK_TOKENS] = tokenized_subtask["input_ids"] + new_observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] = tokenized_subtask["attention_mask"].to( + dtype=torch.bool + ) + return new_observation def _detect_device(self, transition: EnvTransition) -> torch.device | None: diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index 43a61b4f7..ecd54844c 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -26,6 
+26,9 @@ OBS_IMAGES = OBS_IMAGE + "s" OBS_LANGUAGE = OBS_STR + ".language" OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" +OBS_LANGUAGE_SUBTASK = OBS_STR + ".subtask" +OBS_LANGUAGE_SUBTASK_TOKENS = OBS_LANGUAGE_SUBTASK + ".tokens" +OBS_LANGUAGE_SUBTASK_ATTENTION_MASK = OBS_LANGUAGE_SUBTASK + ".attention_mask" ACTION = "action" ACTION_PREFIX = ACTION + "." diff --git a/tests/datasets/test_subtask_dataset.py b/tests/datasets/test_subtask_dataset.py new file mode 100644 index 000000000..f80a6c72d --- /dev/null +++ b/tests/datasets/test_subtask_dataset.py @@ -0,0 +1,190 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for subtask functionality in LeRobotDataset. 
+ +These tests verify that: +- Subtask information is correctly loaded from datasets that have subtask data +- The __getitem__ method correctly adds subtask strings to returned items +- Subtask handling gracefully handles missing data +""" + +import pandas as pd +import pytest +import torch + +from lerobot.datasets.lerobot_dataset import LeRobotDataset + + +class TestSubtaskDataset: + """Tests for subtask handling in LeRobotDataset.""" + + @pytest.fixture + def subtask_dataset(self): + """Load the test subtask dataset from the hub.""" + # Use lerobot/pusht-subtask dataset with episode 1 + return LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + + def test_subtask_dataset_loads(self, subtask_dataset): + """Test that the subtask dataset loads successfully.""" + assert subtask_dataset is not None + assert len(subtask_dataset) > 0 + + def test_subtask_metadata_loaded(self, subtask_dataset): + """Test that subtask metadata is loaded when present in dataset.""" + # The dataset should have subtasks metadata loaded + assert subtask_dataset.meta.subtasks is not None + assert isinstance(subtask_dataset.meta.subtasks, pd.DataFrame) + + def test_subtask_index_in_features(self, subtask_dataset): + """Test that subtask_index is a feature when dataset has subtasks.""" + assert "subtask_index" in subtask_dataset.features + + def test_getitem_returns_subtask_string(self, subtask_dataset): + """Test that __getitem__ correctly adds subtask string to returned item.""" + item = subtask_dataset[0] + + # Subtask should be present in the returned item + assert "subtask" in item + assert isinstance(item["subtask"], str) + assert len(item["subtask"]) > 0 # Should not be empty + + def test_getitem_has_subtask_index(self, subtask_dataset): + """Test that __getitem__ includes subtask_index.""" + item = subtask_dataset[0] + + assert "subtask_index" in item + assert isinstance(item["subtask_index"], torch.Tensor) + + def 
test_subtask_index_maps_to_valid_subtask(self, subtask_dataset): + """Test that subtask_index correctly maps to a subtask in metadata.""" + item = subtask_dataset[0] + + subtask_idx = item["subtask_index"].item() + subtask_from_metadata = subtask_dataset.meta.subtasks.iloc[subtask_idx].name + + assert item["subtask"] == subtask_from_metadata + + def test_all_items_have_subtask(self, subtask_dataset): + """Test that all items in the dataset have subtask information.""" + for i in range(min(len(subtask_dataset), 5)): # Check first 5 items + item = subtask_dataset[i] + assert "subtask" in item + assert isinstance(item["subtask"], str) + + def test_task_and_subtask_coexist(self, subtask_dataset): + """Test that both task and subtask are present in returned items.""" + item = subtask_dataset[0] + + # Both task and subtask should be present + assert "task" in item + assert "subtask" in item + assert isinstance(item["task"], str) + assert isinstance(item["subtask"], str) + + +class TestSubtaskDatasetMissing: + """Tests for graceful handling when subtask data is missing.""" + + @pytest.fixture + def dataset_without_subtasks(self, tmp_path, empty_lerobot_dataset_factory): + """Create a dataset without subtask information.""" + features = {"state": {"dtype": "float32", "shape": (2,), "names": None}} + dataset = empty_lerobot_dataset_factory(root=tmp_path / "no_subtask", features=features) + + # Add some frames and save + for _ in range(5): + dataset.add_frame({"state": torch.randn(2), "task": "Test task"}) + dataset.save_episode() + dataset.finalize() + + # Reload the dataset + return LeRobotDataset(dataset.repo_id, root=dataset.root) + + def test_no_subtask_in_features(self, dataset_without_subtasks): + """Test that subtask_index is not in features when not provided.""" + assert "subtask_index" not in dataset_without_subtasks.features + + def test_getitem_without_subtask(self, dataset_without_subtasks): + """Test that __getitem__ works when subtask is not present.""" + item 
= dataset_without_subtasks[0] + + # Item should still be retrievable + assert item is not None + assert "state" in item + assert "task" in item + + # Subtask should NOT be present + assert "subtask" not in item + + def test_subtasks_metadata_is_none(self, dataset_without_subtasks): + """Test that subtasks metadata is None when not present.""" + assert dataset_without_subtasks.meta.subtasks is None + + +class TestSubtaskEdgeCases: + """Edge case tests for subtask handling.""" + + def test_subtask_with_multiple_episodes(self): + """Test subtask handling with multiple episodes if available.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + # Check first and last items have valid subtasks + first_item = dataset[0] + last_item = dataset[len(dataset) - 1] + + assert "subtask" in first_item + assert "subtask" in last_item + assert isinstance(first_item["subtask"], str) + assert isinstance(last_item["subtask"], str) + + def test_subtask_index_consistency(self): + """Test that same subtask_index returns same subtask string.""" + try: + dataset = LeRobotDataset( + repo_id="lerobot/pusht-subtask", + episodes=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + ) + except Exception: + pytest.skip("Could not load test-subtask dataset") + + if len(dataset) < 2: + pytest.skip("Dataset too small for this test") + + # Collect subtask_index to subtask mappings + subtask_map = {} + for i in range(min(len(dataset), 10)): + item = dataset[i] + idx = item["subtask_index"].item() + subtask = item["subtask"] + + if idx in subtask_map: + # Same index should always return same subtask + assert subtask_map[idx] == subtask, ( + f"Inconsistent subtask for index {idx}: '{subtask_map[idx]}' vs '{subtask}'" + ) + else: + subtask_map[idx] = subtask diff --git a/tests/processor/test_tokenizer_processor.py b/tests/processor/test_tokenizer_processor.py index 
d6f87f567..64cc8aac8 100644 --- a/tests/processor/test_tokenizer_processor.py +++ b/tests/processor/test_tokenizer_processor.py @@ -27,7 +27,14 @@ import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.processor import DataProcessorPipeline, TokenizerProcessorStep, TransitionKey from lerobot.processor.converters import create_transition, identity_transition -from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_LANGUAGE, OBS_STATE +from lerobot.utils.constants import ( + ACTION, + OBS_IMAGE, + OBS_LANGUAGE, + OBS_LANGUAGE_SUBTASK_ATTENTION_MASK, + OBS_LANGUAGE_SUBTASK_TOKENS, + OBS_STATE, +) from tests.utils import require_package @@ -1038,3 +1045,459 @@ def test_simulated_accelerate_scenario(): # MockTokenizer squeezes single-item batches, so shape is (max_length,) not (1, max_length) assert tokens.shape == (10,) # MockTokenizer behavior for single string in list assert attention_mask.shape == (10,) + + +# ============================================================================= +# Tests for get_subtask method +# ============================================================================= + + +@require_package("transformers") +def test_get_subtask_missing_key(): + """Test get_subtask returns None when subtask key is missing from complementary_data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No "subtask" key + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_value(): + """Test get_subtask returns None when subtask value is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + 
transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_none_complementary_data(): + """Test get_subtask returns None when complementary_data is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data=None, # No complementary data + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_string(): + """Test get_subtask returns list with single string when subtask is a string.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the cube"}, + ) + + result = processor.get_subtask(transition) + assert result == ["pick up the cube"] + assert isinstance(result, list) + assert len(result) == 1 + + +@require_package("transformers") +def test_get_subtask_list_of_strings(): + """Test get_subtask returns the list when subtask is already a list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + subtask_list = ["pick up", "move to target", "place down"] + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": subtask_list}, + ) + + result = 
processor.get_subtask(transition) + assert result == subtask_list + assert isinstance(result, list) + assert len(result) == 3 + + +@require_package("transformers") +def test_get_subtask_unsupported_type_integer(): + """Test get_subtask returns None when subtask is an unsupported type (integer).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": 123}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_mixed_list(): + """Test get_subtask returns None when subtask is a list with mixed types.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ["valid string", 123, "another string"]}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_unsupported_type_dict(): + """Test get_subtask returns None when subtask is a dictionary.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": {"key": "value"}}, + ) + + result = processor.get_subtask(transition) + assert result is None + + +@require_package("transformers") +def test_get_subtask_empty_string(): + """Test get_subtask with empty string returns list with empty string.""" + mock_tokenizer = 
MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ""}, + ) + + result = processor.get_subtask(transition) + assert result == [""] + + +@require_package("transformers") +def test_get_subtask_empty_list(): + """Test get_subtask with empty list returns empty list.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": []}, + ) + + result = processor.get_subtask(transition) + assert result == [] + + +# ============================================================================= +# Tests for subtask tokenization in observation method +# ============================================================================= + + +@require_package("transformers") +def test_subtask_tokenization_when_present(): + """Test that subtask is tokenized and added to observation when present.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up the red cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check token structure + subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = 
observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert isinstance(subtask_tokens, torch.Tensor) + assert isinstance(subtask_attention_mask, torch.Tensor) + assert subtask_tokens.shape == (8,) + assert subtask_attention_mask.shape == (8,) + assert subtask_attention_mask.dtype == torch.bool + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_none(): + """Test that subtask tokens are NOT added to observation when subtask is None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task"}, # No subtask + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + # But main task tokens should still be present + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + +@require_package("transformers") +def test_subtask_tokenization_not_added_when_subtask_value_is_none(): + """Test that subtask tokens are NOT added when subtask value is explicitly None.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": None}, + ) + + result = processor(transition) + + # Check that subtask tokens were NOT added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in observation + + 
+@require_package("transformers") +def test_subtask_tokenization_list_of_strings(): + """Test subtask tokenization with list of strings.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": ["pick up", "place down"]}, + ) + + result = processor(transition) + + # Check that subtask tokens were added to observation + observation = result[TransitionKey.OBSERVATION] + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check token structure for batch + subtask_tokens = observation[OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert subtask_tokens.shape == (2, 8) # batch_size=2, seq_len=8 + assert subtask_attention_mask.shape == (2, 8) + + +@require_package("transformers") +def test_subtask_tokenization_device_cpu(): + """Test that subtask tokens are on CPU when other tensors are on CPU.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CPU tensors + observation = {OBS_STATE: torch.randn(10)} # CPU tensor + action = torch.randn(5) # CPU tensor + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens are on CPU + subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + assert subtask_tokens.device.type == "cpu" + assert subtask_attention_mask.device.type == "cpu" + + +@pytest.mark.skipif(not 
torch.cuda.is_available(), reason="CUDA not available") +@require_package("transformers") +def test_subtask_tokenization_device_cuda(): + """Test that subtask tokens are moved to CUDA when other tensors are on CUDA.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + # Create transition with CUDA tensors + observation = {OBS_STATE: torch.randn(10).cuda()} # CUDA tensor + action = torch.randn(5).cuda() # CUDA tensor + transition = create_transition( + observation=observation, + action=action, + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + + # Check that subtask tokens are on CUDA + subtask_tokens = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_attention_mask = result[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + assert subtask_tokens.device.type == "cuda" + assert subtask_attention_mask.device.type == "cuda" + + +@require_package("transformers") +def test_subtask_tokenization_preserves_other_observation_data(): + """Test that subtask tokenization preserves other observation data.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + original_state = torch.tensor([1.0, 2.0, 3.0]) + transition = create_transition( + observation={"state": original_state.clone()}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + # Check that original observation data is preserved + assert torch.equal(observation["state"], original_state) + + # Check that both task and subtask tokens are present + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert 
OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + +@require_package("transformers") +def test_subtask_attention_mask_dtype(): + """Test that subtask attention mask has correct dtype (bool).""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "pick up cube"}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + subtask_attention_mask = observation[OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + assert subtask_attention_mask.dtype == torch.bool + + +@require_package("transformers") +def test_subtask_tokenization_deterministic(): + """Test that subtask tokenization is deterministic for the same input.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=10) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "consistent subtask"}, + ) + + result1 = processor(transition) + result2 = processor(transition) + + subtask_tokens1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_tokens2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_TOKENS] + subtask_mask1 = result1[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + subtask_mask2 = result2[TransitionKey.OBSERVATION][OBS_LANGUAGE_SUBTASK_ATTENTION_MASK] + + # Results should be identical + assert torch.equal(subtask_tokens1, subtask_tokens2) + assert torch.equal(subtask_mask1, subtask_mask2) + + +@require_package("transformers") +@patch("lerobot.processor.tokenizer_processor.AutoTokenizer") +def test_subtask_tokenization_integration_with_pipeline(mock_auto_tokenizer): + """Test subtask tokenization 
works correctly with DataProcessorPipeline.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer + + tokenizer_processor = TokenizerProcessorStep(tokenizer_name="test-tokenizer", max_length=6) + robot_processor = DataProcessorPipeline( + [tokenizer_processor], to_transition=identity_transition, to_output=identity_transition + ) + + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": "subtask instruction"}, + ) + + result = robot_processor(transition) + + # Check that observation exists and both tokenizations were applied + assert TransitionKey.OBSERVATION in result + observation = result[TransitionKey.OBSERVATION] + + # Check task tokens + assert f"{OBS_LANGUAGE}.tokens" in observation + assert f"{OBS_LANGUAGE}.attention_mask" in observation + + # Check subtask tokens + assert OBS_LANGUAGE_SUBTASK_TOKENS in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK in observation + + # Check shapes + assert observation[f"{OBS_LANGUAGE}.tokens"].shape == (6,) + assert observation[OBS_LANGUAGE_SUBTASK_TOKENS].shape == (6,) + + +@require_package("transformers") +def test_subtask_not_added_for_unsupported_types(): + """Test that subtask tokens are not added when subtask has unsupported type.""" + mock_tokenizer = MockTokenizer(vocab_size=100) + processor = TokenizerProcessorStep(tokenizer=mock_tokenizer, max_length=8) + + # Test with integer subtask + transition = create_transition( + observation={"state": torch.tensor([1.0, 2.0])}, + action=torch.tensor([0.1, 0.2]), + complementary_data={"task": "main task", "subtask": 123}, + ) + + result = processor(transition) + observation = result[TransitionKey.OBSERVATION] + + # Subtask tokens should NOT be added for unsupported types + assert OBS_LANGUAGE_SUBTASK_TOKENS not in observation + assert OBS_LANGUAGE_SUBTASK_ATTENTION_MASK not in 
observation + + # But main task tokens should still be present + assert f"{OBS_LANGUAGE}.tokens" in observation From 9c24a09665ffe1338a91ee7a05a0e76d10e7e3d4 Mon Sep 17 00:00:00 2001 From: Hirokazu Ishida <38597814+HiroIshida@users.noreply.github.com> Date: Tue, 3 Feb 2026 04:05:58 +0900 Subject: [PATCH 076/111] docs: update document in response to Simplify configs PR (#1596) * docs: update document input/output_shapes -> input/output_features * fix inconsistent quote (suggested by copilot reviewer) * docs: shapes => PolicyFeature * docs: relfect normalization_mapping and remove outdated --- src/lerobot/configs/policies.py | 12 ++++----- src/lerobot/policies/act/configuration_act.py | 23 +++++----------- .../diffusion/configuration_diffusion.py | 23 +++++----------- .../policies/tdmpc/configuration_tdmpc.py | 26 +++++-------------- .../policies/vqbet/configuration_vqbet.py | 23 +++++----------- 5 files changed, 34 insertions(+), 73 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index 7f326b70b..44b013c29 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -45,12 +45,12 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno Args: n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the current step and additional steps going back). - input_shapes: A dictionary defining the shapes of the input data for the policy. - output_shapes: A dictionary defining the shapes of the output data for the policy. - input_normalization_modes: A dictionary with key representing the modality and the value specifies the - normalization mode to apply. - output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to - the original scale. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. 
The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) """ n_obs_steps: int = 1 diff --git a/src/lerobot/policies/act/configuration_act.py b/src/lerobot/policies/act/configuration_act.py index 6f6c1c4be..bd89185fd 100644 --- a/src/lerobot/policies/act/configuration_act.py +++ b/src/lerobot/policies/act/configuration_act.py @@ -28,7 +28,7 @@ class ACTConfig(PreTrainedConfig): Defaults are configured for training on bimanual Aloha tasks like "insertion" or "transfer". The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and 'output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - Either: @@ -48,21 +48,12 @@ class ACTConfig(PreTrainedConfig): This should be no greater than the chunk size. For example, if the chunk size size 100, you may set this to 50. This would mean that the model predicts 100 steps worth of actions, runs 50 in the environment, and throws the other 50 out. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. 
- output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. pretrained_backbone_weights: Pretrained weights from torchvision to initialize the backbone. `None` means no pretrained weights. 
diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 54569434a..8322ca337 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -30,7 +30,7 @@ class DiffusionConfig(PreTrainedConfig): Defaults are configured for training with PushT providing proprioceptive and single camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and `output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - "observation.state" is required as an input key. @@ -48,21 +48,12 @@ class DiffusionConfig(PreTrainedConfig): horizon: Diffusion model action prediction size as detailed in `DiffusionPolicy.select_action`. n_action_steps: The number of action steps to run in the environment for one invocation of the policy. See `DiffusionPolicy.select_action` for more details. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. 
"observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. diff --git a/src/lerobot/policies/tdmpc/configuration_tdmpc.py b/src/lerobot/policies/tdmpc/configuration_tdmpc.py index 3c1a29932..3ec493472 100644 --- a/src/lerobot/policies/tdmpc/configuration_tdmpc.py +++ b/src/lerobot/policies/tdmpc/configuration_tdmpc.py @@ -30,7 +30,7 @@ class TDMPCConfig(PreTrainedConfig): camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes`, `output_shapes`, and perhaps `max_random_shift_ratio`. + Those are: `input_features`, `output_features`, and perhaps `max_random_shift_ratio`. Args: n_action_repeats: The number of times to repeat the action returned by the planning. 
(hint: Google @@ -40,24 +40,12 @@ class TDMPCConfig(PreTrainedConfig): is an alternative to using action repeats. If this is set to more than 1, then we require `n_action_repeats == 1`, `use_mpc == True` and `n_action_steps <= horizon`. Note that this approach of using multiple steps from the plan is not in the original implementation. - input_shapes: A dictionary defining the shapes of the input data for the policy. The key represents - the input data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "observation.image" refers to an input from a camera with dimensions [3, 96, 96], - indicating it has three color channels and 96x96 resolution. Importantly, `input_shapes` doesn't - include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. The key represents - the output data name, and the value is a list indicating the dimensions of the corresponding data. - For example, "action" refers to an output shape of [14], indicating 14-dimensional actions. - Importantly, `output_shapes` doesn't include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. Note that here this defaults to None meaning inputs are not normalized. This is to - match the original implementation. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. NOTE: Clipping - to [-1, +1] is used during MPPI/CEM. Therefore, it is recommended that you stick with "min_max" - normalization mode here. 
+ input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) image_encoder_hidden_dim: Number of channels for the convolutional layers used for image encoding. state_encoder_hidden_dim: Hidden dimension for MLP used for state vector encoding. latent_dim: Observation's latent embedding dimension. diff --git a/src/lerobot/policies/vqbet/configuration_vqbet.py b/src/lerobot/policies/vqbet/configuration_vqbet.py index 44ada9f17..32906e528 100644 --- a/src/lerobot/policies/vqbet/configuration_vqbet.py +++ b/src/lerobot/policies/vqbet/configuration_vqbet.py @@ -32,7 +32,7 @@ class VQBeTConfig(PreTrainedConfig): Defaults are configured for training with PushT providing proprioceptive and single camera observations. The parameters you will most likely need to change are the ones which depend on the environment / sensors. - Those are: `input_shapes` and `output_shapes`. + Those are: `input_features` and `output_features`. Notes on the inputs and outputs: - "observation.state" is required as an input key. @@ -46,21 +46,12 @@ class VQBeTConfig(PreTrainedConfig): current step and additional steps going back). n_action_pred_token: Total number of current token and future tokens that VQ-BeT predicts. action_chunk_size: Action chunk size of each action prediction token. - input_shapes: A dictionary defining the shapes of the input data for the policy. 
- The key represents the input data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "observation.image" refers to an input from - a camera with dimensions [3, 96, 96], indicating it has three color channels and 96x96 resolution. - Importantly, shapes doesnt include batch dimension or temporal dimension. - output_shapes: A dictionary defining the shapes of the output data for the policy. - The key represents the output data name, and the value is a list indicating the dimensions - of the corresponding data. For example, "action" refers to an output shape of [14], indicating - 14-dimensional actions. Importantly, shapes doesnt include batch dimension or temporal dimension. - input_normalization_modes: A dictionary with key representing the modality (e.g. "observation.state"), - and the value specifies the normalization mode to apply. The two available modes are "mean_std" - which subtracts the mean and divides by the standard deviation and "min_max" which rescale in a - [-1, 1] range. - output_normalization_modes: Similar dictionary as `normalize_input_modes`, but to unnormalize to the - original scale. Note that this is also used for normalizing the training targets. + input_features: A dictionary defining the PolicyFeature of the input data for the policy. The key represents + the input data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + output_features: A dictionary defining the PolicyFeature of the output data for the policy. The key represents + the output data name, and the value is PolicyFeature, which consists of FeatureType and shape attributes. + normalization_mapping: A dictionary that maps from a str value of FeatureType (e.g., "STATE", "VISUAL") to + a corresponding NormalizationMode (e.g., NormalizationMode.MIN_MAX) vision_backbone: Name of the torchvision resnet backbone to use for encoding images. 
crop_shape: (H, W) shape to crop images to as a preprocessing step for the vision backbone. Must fit within the image size. If None, no cropping is done. From 14a15f90e762170209d283c3545523549841ca3d Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 2 Feb 2026 22:14:03 +0100 Subject: [PATCH 077/111] Add missing RL config options: add_ee_pose_to_observation and gripper_penalty_in_reward (#2873) * fix(RL) add missing config arguments * respond to copilot review * fix(revert penalty in reward): reverting gripper penalty addition in reward. This is already done in compute_loss_discrete_critic. --------- Co-authored-by: CarolinePascal --- src/lerobot/envs/configs.py | 1 + src/lerobot/processor/hil_processor.py | 22 ++++++++++++---------- src/lerobot/rl/gym_manipulator.py | 12 ++++++++++-- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index cd88b37bc..9c1c083a4 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -205,6 +205,7 @@ class ObservationConfig: add_joint_velocity_to_observation: bool = False add_current_to_observation: bool = False + add_ee_pose_to_observation: bool = False display_cameras: bool = False diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 6d44ed8cb..24b5628fa 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -314,7 +314,7 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") -class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): +class GripperPenaltyProcessorStep(ProcessorStep): """ Applies a penalty for inefficient gripper usage. 
@@ -329,26 +329,27 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): penalty: float = -0.01 max_gripper_pos: float = 30.0 - def complementary_data(self, complementary_data: dict) -> dict: + def __call__(self, transition: EnvTransition) -> EnvTransition: """ Calculates the gripper penalty and adds it to the complementary data. Args: - complementary_data: The incoming complementary data, which should contain - raw joint positions. + transition: The incoming environment transition. Returns: - A new complementary data dictionary with the `discrete_penalty` key added. + The modified transition with the penalty added to complementary data. """ - action = self.transition.get(TransitionKey.ACTION) + new_transition = transition.copy() + action = new_transition.get(TransitionKey.ACTION) + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) raw_joint_positions = complementary_data.get("raw_joint_positions") if raw_joint_positions is None: - return complementary_data + return new_transition current_gripper_pos = raw_joint_positions.get(GRIPPER_KEY, None) if current_gripper_pos is None: - return complementary_data + return new_transition # Gripper action is a PolicyAction at this stage gripper_action = action[-1].item() @@ -364,11 +365,12 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep): gripper_penalty = self.penalty * int(gripper_penalty_bool) - # Create new complementary data with penalty info + # Update complementary data with penalty info new_complementary_data = dict(complementary_data) new_complementary_data[DISCRETE_PENALTY_KEY] = gripper_penalty + new_transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data - return new_complementary_data + return new_transition def get_config(self) -> dict[str, Any]: """ diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 3d58ae18f..1c1cb752f 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ 
b/src/lerobot/rl/gym_manipulator.py @@ -412,7 +412,10 @@ def make_processors( if cfg.processor.observation.add_current_to_observation: env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot)) - if kinematics_solver is not None: + add_ee_pose = ( + cfg.processor.observation is not None and cfg.processor.observation.add_ee_pose_to_observation + ) + if kinematics_solver is not None and add_ee_pose: env_pipeline_steps.append( ForwardKinematicsJointsToEEObservation( kinematics=kinematics_solver, @@ -435,7 +438,12 @@ def make_processors( ) # Add gripper penalty processor if gripper config exists and enabled - if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper: + # Only add if max_gripper_pos is explicitly configured (required for normalization) + if ( + cfg.processor.gripper is not None + and cfg.processor.gripper.use_gripper + and cfg.processor.max_gripper_pos is not None + ): env_pipeline_steps.append( GripperPenaltyProcessorStep( penalty=cfg.processor.gripper.gripper_penalty, From a6370dd783c1048096b9596853beccc08a7b0bbd Mon Sep 17 00:00:00 2001 From: Iori Yanokura Date: Tue, 3 Feb 2026 22:17:04 +0900 Subject: [PATCH 078/111] fix(wandb): truncate init tags to 64-character limit (#995) --- src/lerobot/rl/wandb_utils.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/lerobot/rl/wandb_utils.py b/src/lerobot/rl/wandb_utils.py index 7b7f8a57b..ee30b75df 100644 --- a/src/lerobot/rl/wandb_utils.py +++ b/src/lerobot/rl/wandb_utils.py @@ -26,8 +26,21 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.utils.constants import PRETRAINED_MODEL_DIR -def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[str] | str: +def cfg_to_group( + cfg: TrainPipelineConfig, return_list: bool = False, truncate_tags: bool = False, max_tag_length: int = 64 +) -> list[str] | str: """Return a group name for logging. 
Optionally returns group name as list.""" + + def _maybe_truncate(tag: str) -> str: + """Truncate tag to max_tag_length characters if required. + + wandb rejects tags longer than 64 characters. + See: https://github.com/wandb/wandb/blob/main/wandb/sdk/wandb_settings.py + """ + if len(tag) <= max_tag_length: + return tag + return tag[:max_tag_length] + lst = [ f"policy:{cfg.policy.type}", f"seed:{cfg.seed}", @@ -36,6 +49,8 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st lst.append(f"dataset:{cfg.dataset.repo_id}") if cfg.env is not None: lst.append(f"env:{cfg.env.type}") + if truncate_tags: + lst = [_maybe_truncate(tag) for tag in lst] return lst if return_list else "-".join(lst) @@ -83,7 +98,7 @@ class WandBLogger: entity=self.cfg.entity, name=self.job_name, notes=self.cfg.notes, - tags=cfg_to_group(cfg, return_list=True), + tags=cfg_to_group(cfg, return_list=True, truncate_tags=True), dir=self.log_dir, config=cfg.to_dict(), # TODO(rcadene): try set to True From 0f392484458cb5ebca0310c0c4c47390a31c80ed Mon Sep 17 00:00:00 2001 From: jwang078 Date: Tue, 3 Feb 2026 13:19:00 -0500 Subject: [PATCH 079/111] Small docstring fix in diffusion configuration (#2847) --- src/lerobot/policies/diffusion/configuration_diffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index 8322ca337..8ac0920dd 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -64,7 +64,7 @@ class DiffusionConfig(PreTrainedConfig): use_group_norm: Whether to replace batch normalization with group normalization in the backbone. The group sizes are set to be about 16 (to be precise, feature_dim // 16). spatial_softmax_num_keypoints: Number of keypoints for SpatialSoftmax. 
- use_separate_rgb_encoders_per_camera: Whether to use a separate RGB encoder for each camera view. + use_separate_rgb_encoder_per_camera: Whether to use a separate RGB encoder for each camera view. down_dims: Feature dimension for each stage of temporal downsampling in the diffusion modeling Unet. You may provide a variable number of dimensions, therefore also controlling the degree of downsampling. From 97e7e0f9ed8831daee04a6e5f67d777f689c87e4 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Thu, 5 Feb 2026 14:39:58 +0000 Subject: [PATCH 080/111] feat(datasets): improve image transform support (#2885) * improve image transform support * add tests * Add stricter transform check and extra test * improve subclass check --- src/lerobot/datasets/transforms.py | 19 ++++++++++--------- tests/datasets/test_image_transforms.py | 24 ++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/lerobot/datasets/transforms.py b/src/lerobot/datasets/transforms.py index beacc48d9..5240619cb 100644 --- a/src/lerobot/datasets/transforms.py +++ b/src/lerobot/datasets/transforms.py @@ -216,16 +216,17 @@ class ImageTransformsConfig: def make_transform_from_config(cfg: ImageTransformConfig): - if cfg.type == "Identity": - return v2.Identity(**cfg.kwargs) - elif cfg.type == "ColorJitter": - return v2.ColorJitter(**cfg.kwargs) - elif cfg.type == "SharpnessJitter": + if cfg.type == "SharpnessJitter": return SharpnessJitter(**cfg.kwargs) - elif cfg.type == "RandomAffine": - return v2.RandomAffine(**cfg.kwargs) - else: - raise ValueError(f"Transform '{cfg.type}' is not valid.") + + transform_cls = getattr(v2, cfg.type, None) + if isinstance(transform_cls, type) and issubclass(transform_cls, Transform): + return transform_cls(**cfg.kwargs) + + raise ValueError( + f"Transform '{cfg.type}' is not valid. It must be a class in " + f"torchvision.transforms.v2 or 'SharpnessJitter'." 
+ ) class ImageTransforms(Transform): diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 8a66ceb24..ef7e8c395 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -390,6 +390,30 @@ def test_sharpness_jitter_invalid_range_max_smaller(): SharpnessJitter((2.0, 0.1)) +def test_make_transform_from_config_with_v2_resize(img_tensor_factory): + img_tensor = img_tensor_factory() + tf_cfg = ImageTransformConfig(type="Resize", kwargs={"size": (32, 32)}) + tf = make_transform_from_config(tf_cfg) + assert isinstance(tf, v2.Resize) + output = tf(img_tensor) + assert output.shape[-2:] == (32, 32) + + +def test_make_transform_from_config_with_v2_identity(img_tensor_factory): + img_tensor = img_tensor_factory() + tf_cfg = ImageTransformConfig(type="Identity", kwargs={}) + tf = make_transform_from_config(tf_cfg) + assert isinstance(tf, v2.Identity) + output = tf(img_tensor) + assert output.shape == img_tensor.shape + + +def test_make_transform_from_config_invalid_type(): + tf_cfg = ImageTransformConfig(type="NotARealTransform", kwargs={}) + with pytest.raises(ValueError, match="not valid"): + make_transform_from_config(tf_cfg) + + def test_save_all_transforms(img_tensor_factory, tmp_path): img_tensor = img_tensor_factory() tf_cfg = ImageTransformsConfig(enable=True) From e14bdf57d055e85ebc8a684efd2e4b9a4c7b6a37 Mon Sep 17 00:00:00 2001 From: Reece O'Mahoney <66252930+reeceomahoney@users.noreply.github.com> Date: Mon, 9 Feb 2026 13:46:12 +0000 Subject: [PATCH 081/111] Convert tensors to scalars (#2903) Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/modeling_smolvla.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index c611e9ba2..60b968a42 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ 
b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -378,16 +378,16 @@ class SmolVLAPolicy(PreTrainedPolicy): actions_is_pad = batch.get("actions_id_pad") loss_dict = {} losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) - loss_dict["losses_after_forward"] = losses.clone() + loss_dict["losses_after_forward"] = losses.clone().mean().item() if actions_is_pad is not None: in_episode_bound = ~actions_is_pad losses = losses * in_episode_bound.unsqueeze(-1) - loss_dict["losses_after_in_ep_bound"] = losses.clone() + loss_dict["losses_after_in_ep_bound"] = losses.clone().mean().item() # Remove padding losses = losses[:, :, : self.config.max_action_dim] - loss_dict["losses_after_rm_padding"] = losses.clone() + loss_dict["losses_after_rm_padding"] = losses.clone().mean().item() if reduction == "none": # Return per-sample losses (B,) by averaging over time and action dims From 489cb7b6b9a39b569aaf02ff26df6725e9b36285 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 9 Feb 2026 16:58:32 +0100 Subject: [PATCH 082/111] fix(scripts): correct can import check (#2937) --- src/lerobot/scripts/lerobot_setup_can.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py index 55de74724..a31727ea4 100644 --- a/src/lerobot/scripts/lerobot_setup_can.py +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -45,7 +45,7 @@ from dataclasses import dataclass, field import draccus -from lerobot.utils.import_utils import is_package_available +from lerobot.utils.import_utils import _can_available MOTOR_NAMES = { 0x01: "joint_1", @@ -336,7 +336,7 @@ def run_speed(cfg: CANSetupConfig): @draccus.wrap() def setup_can(cfg: CANSetupConfig): - if not is_package_available("can"): + if not _can_available: print("Error: python-can not installed. 
Install with: pip install python-can") sys.exit(1) From cca0296cd6f0f281c5fd4628e836403628b59a05 Mon Sep 17 00:00:00 2001 From: Stepan Feduniak Date: Tue, 10 Feb 2026 13:55:11 +0100 Subject: [PATCH 083/111] fix(pipeline): use FeatureType for STATE features in Libero processor (#2888) * fix the types * pre-commit --------- Co-authored-by: Jade Choghari Co-authored-by: Steven Palma --- src/lerobot/processor/env_processor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/lerobot/processor/env_processor.py b/src/lerobot/processor/env_processor.py index 8d42bfdb7..a77e066cf 100644 --- a/src/lerobot/processor/env_processor.py +++ b/src/lerobot/processor/env_processor.py @@ -17,7 +17,7 @@ from dataclasses import dataclass import torch -from lerobot.configs.types import PipelineFeatureType, PolicyFeature +from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature from lerobot.utils.constants import OBS_IMAGES, OBS_PREFIX, OBS_STATE, OBS_STR from .pipeline import ObservationProcessorStep, ProcessorStepRegistry @@ -92,7 +92,7 @@ class LiberoProcessorStep(ObservationProcessorStep): # copy over non-STATE features for ft, feats in features.items(): - if ft != PipelineFeatureType.STATE: + if ft != FeatureType.STATE: new_features[ft] = feats.copy() # rebuild STATE features @@ -100,13 +100,11 @@ class LiberoProcessorStep(ObservationProcessorStep): # add our new flattened state state_feats[OBS_STATE] = PolicyFeature( - key=OBS_STATE, + type=FeatureType.STATE, shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)] - dtype="float32", - description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."), ) - new_features[PipelineFeatureType.STATE] = state_feats + new_features[FeatureType.STATE] = state_feats return new_features From 5eba4ce6f453c2dfe4458b037bf3612df22f81ee Mon Sep 17 00:00:00 2001 From: Aoqun Jin Date: Tue, 10 Feb 2026 21:39:17 +0800 Subject: [PATCH 084/111] Change LIBERO init_state_id 
when reset. (#2899) * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * Change LIBERO init_state_id when reset. Signed-off-by: Aoqun Jin * pre-commit run --------- Signed-off-by: Aoqun Jin Co-authored-by: Jade Choghari Co-authored-by: Steven Palma --- src/lerobot/envs/libero.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/lerobot/envs/libero.py b/src/lerobot/envs/libero.py index 96c5cf102..d20dae8ea 100644 --- a/src/lerobot/envs/libero.py +++ b/src/lerobot/envs/libero.py @@ -112,6 +112,7 @@ class LiberoEnv(gym.Env): visualization_height: int = 480, init_states: bool = True, episode_index: int = 0, + n_envs: int = 1, camera_name_mapping: dict[str, str] | None = None, num_steps_wait: int = 10, control_mode: str = "relative", @@ -145,7 +146,9 @@ class LiberoEnv(gym.Env): self.episode_length = episode_length # Load once and keep self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None - self._init_state_id = self.episode_index # tie each sub-env to a fixed init state + self._reset_stride = n_envs # when performing a reset, append `_reset_stride` to `init_state_id`. + + self.init_state_id = self.episode_index # tie each sub-env to a fixed init state self._env = self._make_envs_task(task_suite, self.task_id) default_steps = 500 @@ -295,7 +298,8 @@ class LiberoEnv(gym.Env): self._env.seed(seed) raw_obs = self._env.reset() if self.init_states and self._init_states is not None: - raw_obs = self._env.set_init_state(self._init_states[self._init_state_id]) + raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)]) + self.init_state_id += self._reset_stride # Change init_state_id when reset # After reset, objects may be unstable (slightly floating, intersecting, etc.). # Step the simulator with a no-op action for a few frames so everything settles. 
@@ -373,6 +377,7 @@ def _make_env_fns( init_states=init_states, episode_length=episode_length, episode_index=episode_index, + n_envs=n_envs, control_mode=control_mode, **local_kwargs, ) From d2d01399d6773427347a37be401ec6ea35fa0e15 Mon Sep 17 00:00:00 2001 From: Jai Kumaar Ratadia Date: Tue, 10 Feb 2026 14:18:32 +0000 Subject: [PATCH 085/111] docs: clarify installation steps are sequential, not optional (#2925) * docs: clarify installation steps are sequential, not optional Add intro paragraph noting conda is one path (not the only one) and number the three sections as steps so readers understand miniforge and environment setup are prerequisites, not independent choices. * Update installation guide link for LeRobot Signed-off-by: Jai Kumaar Ratadia * Fix link formatting in installation guide again Signed-off-by: Jai Kumaar Ratadia --------- Signed-off-by: Jai Kumaar Ratadia Co-authored-by: Steven Palma --- docs/source/installation.mdx | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 44d8c7034..8cc83843e 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,13 +1,15 @@ # Installation -## Install [`miniforge`](https://conda-forge.org/download/) +This guide uses conda (via miniforge) to manage environments. If you prefer another environment manager (e.g. `uv`, `venv`), ensure you have Python >=3.10 and ffmpeg installed with the `libsvtav1` encoder, then skip ahead to [Install LeRobot](#step-3-install-lerobot-). 
+ +## Step 1: Install [`miniforge`](https://conda-forge.org/download/) ```bash wget "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh" bash Miniforge3-$(uname)-$(uname -m).sh ``` -## Environment Setup +## Step 2: Environment Setup Create a virtual environment with Python 3.10, using conda: @@ -38,7 +40,7 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -## Install LeRobot 🤗 +## Step 3: Install LeRobot 🤗 ### From Source From 778db19a178b2fee539934419b1475b1f411bac9 Mon Sep 17 00:00:00 2001 From: whats2000 <60466660+whats2000@users.noreply.github.com> Date: Tue, 10 Feb 2026 22:21:40 +0800 Subject: [PATCH 086/111] [Bug Fix] fix(ci): prevent runner group error on fork pushes (#2911) * fix(ci): prevent runner group error on fork pushes Add repository check to unbound_deps_tests workflow to ensure aws-general-8-plus runner group is only used on main repository, preventing 'Required runner group not found' errors on forks. * fix(ci): use gating job to prevent runner allocation on forks The previous approach failed because GitHub evaluates runs-on before if conditions. Now using a check-repo job that runs on ubuntu-latest first, and all jobs with special runners depend on it and check its output before being scheduled. * fix(ci): add gating job to full_tests to prevent runner allocation on forks Apply the same gating pattern used in unbound_deps_tests to full_tests.yml to prevent GitHub from trying to allocate custom runners when workflows run on forks. 
The check-repo job runs first on ubuntu-latest and all jobs with custom runners depend on it and check its output. * fix(ci): add repository check to unbound_deps_tests workflow Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker job to prevent runner group access errors on forks, matching the pattern used in nightly.yml * fix(ci): add repository check to full_tests workflow Add 'if: github.repository == huggingface/lerobot' check to build-and-push-docker and gpu-tests jobs to prevent runner group access errors on forks * refactor(ci): remove redundant check from gpu-tests job gpu-tests depends on build-and-push-docker via needs, so it will automatically skip when the parent job is skipped * refactor(ci): remove unnecessary fork check from full-tests job full-tests runs on ubuntu-latest which is available to all forks, no need to restrict it --------- Co-authored-by: Steven Palma --- .github/workflows/full_tests.yml | 8 +++++--- .github/workflows/unbound_deps_tests.yml | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index 4dce3121a..fd5e422b3 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -101,9 +101,11 @@ jobs: runs-on: group: aws-general-8-plus if: | - (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || - github.event_name == 'push' || - github.event_name == 'workflow_dispatch' + github.repository == 'huggingface/lerobot' && ( + (github.event_name == 'pull_request_review' && github.event.review.state == 'approved' && github.event.pull_request.head.repo.fork == false) || + github.event_name == 'push' || + github.event_name == 'workflow_dispatch' + ) outputs: image_tag: ${{ steps.set_tag.outputs.image_tag }} env: diff --git a/.github/workflows/unbound_deps_tests.yml b/.github/workflows/unbound_deps_tests.yml index 
a75ecc121..3f4ea3316 100644 --- a/.github/workflows/unbound_deps_tests.yml +++ b/.github/workflows/unbound_deps_tests.yml @@ -91,6 +91,7 @@ jobs: name: Build and Push Docker runs-on: group: aws-general-8-plus + if: github.repository == 'huggingface/lerobot' outputs: image_tag: ${{ env.DOCKER_IMAGE_NAME }} env: From 35363c5798d129d7667c2efa43ddfa342639a35a Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 17:35:39 +0100 Subject: [PATCH 087/111] chore(linter): ensure motors module passes MyPy type checks (#2939) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: ensure motors module passes MyPy type checks This commit fixes 62 mypy type errors in the motors module by: - Updating Protocol classes (PortHandler, PacketHandler, GroupSyncRead, GroupSyncWrite) to use class-level attribute declarations instead of __init__ body declarations - Adding missing `broadcastPing` method to PacketHandler Protocol - Fixing return type annotations (e.g., `_get_motor_model` returns str, not int) - Fixing parameter types to use `Sequence` for covariant list parameters - Fixing `Mapping` for covariant dict value types in `_normalize` - Updating method signatures to be consistent across parent and child classes (disable_torque, enable_torque, _get_half_turn_homings) - Adding explicit `int()` casts for MotorCalibration arguments - Adding explicit `return None` for functions returning Optional types - Adding type annotations for variables like `data_list: dict[int, int]` - Using `# type: ignore[method-assign]` for intentional monkeypatch - Fixing variable references (using `self.groups` instead of `groups`) Fixes #1723 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 * chore(style): pre-commit after main merge * chore(linter): solve comments * chore(linter): apply pre-commit fixes to damiao * chore(linter): more fixes to damiao --------- Co-authored-by: yurekami Co-authored-by: 
Claude Opus 4.5 --- pyproject.toml | 6 +- src/lerobot/motors/calibration_gui.py | 10 +- src/lerobot/motors/damiao/damiao.py | 42 ++++- src/lerobot/motors/dynamixel/dynamixel.py | 16 +- src/lerobot/motors/feetech/feetech.py | 18 +-- src/lerobot/motors/motors_bus.py | 183 +++++++++++----------- 6 files changed, 157 insertions(+), 118 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 210d70b6b..c4b1c547e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -360,9 +360,9 @@ ignore_errors = false module = "lerobot.cameras.*" ignore_errors = false -# [[tool.mypy.overrides]] -# module = "lerobot.motors.*" -# ignore_errors = false +[[tool.mypy.overrides]] +module = "lerobot.motors.*" +ignore_errors = false # [[tool.mypy.overrides]] # module = "lerobot.robots.*" diff --git a/src/lerobot/motors/calibration_gui.py b/src/lerobot/motors/calibration_gui.py index 02bba454f..3410cb28a 100644 --- a/src/lerobot/motors/calibration_gui.py +++ b/src/lerobot/motors/calibration_gui.py @@ -221,7 +221,7 @@ class RangeFinderGUI: self.bus = bus self.groups = groups if groups is not None else {"all": list(bus.motors)} - self.group_names = list(groups) + self.group_names = list(self.groups) self.current_group = self.group_names[0] if not bus.is_connected: @@ -230,18 +230,20 @@ class RangeFinderGUI: self.calibration = bus.read_calibration() self.res_table = bus.model_resolution_table self.present_cache = { - m: bus.read("Present_Position", m, normalize=False) for motors in groups.values() for m in motors + m: bus.read("Present_Position", m, normalize=False) + for motors in self.groups.values() + for m in motors } pygame.init() self.font = pygame.font.Font(None, FONT_SIZE) - label_pad = max(self.font.size(m)[0] for ms in groups.values() for m in ms) + label_pad = max(self.font.size(m)[0] for ms in self.groups.values() for m in ms) self.label_pad = label_pad width = 40 + label_pad + BAR_LEN + 6 + BTN_W + 10 + SAVE_W + 10 self.controls_bottom = 10 + SAVE_H self.base_y = 
self.controls_bottom + TOP_GAP - height = self.base_y + PADDING_Y * len(groups[self.current_group]) + 40 + height = self.base_y + PADDING_Y * len(self.groups[self.current_group]) + 40 self.screen = pygame.display.set_mode((width, height)) pygame.display.set_caption("Motors range finder") diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index c79f8d17e..95a9e70d1 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -211,6 +211,9 @@ class DamiaoMotorsBus(MotorsBusBase): logger.info("Starting handshake with motors...") # Drain any pending messages + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + while self.canbus.recv(timeout=0.01): pass @@ -283,6 +286,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [0xFF] * 7 + [command_byte] msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) if msg := self._recv_motor_response(expected_recv_id=recv_id): self._process_response(motor_name, msg) @@ -341,6 +348,10 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id = self._get_motor_recv_id(motor) data = [motor_id & 0xFF, (motor_id >> 8) & 0xFF, CAN_CMD_REFRESH, 0, 0, 0, 0, 0] msg = can.Message(arbitration_id=CAN_PARAM_ID, data=data, is_extended_id=False, is_fd=self.use_can_fd) + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + self.canbus.send(msg) return self._recv_motor_response(expected_recv_id=recv_id) @@ -356,6 +367,10 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: CAN message if received, None otherwise """ + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: start_time = time.time() messages_seen = [] @@ -394,10 +409,13 @@ class DamiaoMotorsBus(MotorsBusBase): Returns: Dictionary mapping recv_id to CAN message """ 
- responses = {} + responses: dict[int, can.Message] = {} expected_set = set(expected_recv_ids) start_time = time.time() + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + try: while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: # 100us poll timeout @@ -461,6 +479,9 @@ class DamiaoMotorsBus(MotorsBusBase): motor_name = self._get_motor_name(motor) motor_type = self._motor_types[motor_name] + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False, is_fd=self.use_can_fd) self.canbus.send(msg) @@ -488,6 +509,9 @@ class DamiaoMotorsBus(MotorsBusBase): recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Step 1: Send all MIT control commands for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): motor_id = self._get_motor_id(motor) @@ -656,6 +680,10 @@ class DamiaoMotorsBus(MotorsBusBase): def _batch_refresh(self, motors: list[str]) -> None: """Internal helper to refresh a list of motors and update cache.""" + + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") + # Send refresh commands for motor in motors: motor_id = self._get_motor_id(motor) @@ -678,10 +706,14 @@ class DamiaoMotorsBus(MotorsBusBase): else: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """ Write values to multiple motors simultaneously. Positions are always in degrees. 
""" + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if data_name in ("Kp", "Kd"): key = data_name.lower() for motor, val in values.items(): @@ -690,6 +722,8 @@ class DamiaoMotorsBus(MotorsBusBase): elif data_name == "Goal_Position": # Step 1: Send all MIT control commands recv_id_to_motor: dict[int, str] = {} + if self.canbus is None: + raise RuntimeError("CAN bus is not initialized.") for motor, value_degrees in values.items(): motor_id = self._get_motor_id(motor) motor_name = self._get_motor_name(motor) @@ -732,9 +766,9 @@ class DamiaoMotorsBus(MotorsBusBase): def record_ranges_of_motion( self, - motors: NameOrID | list[NameOrID] | None = None, + motors: str | list[str] | None = None, display_values: bool = True, - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + ) -> tuple[dict[str, Value], dict[str, Value]]: """ Interactively record the min/max values of each motor in degrees. diff --git a/src/lerobot/motors/dynamixel/dynamixel.py b/src/lerobot/motors/dynamixel/dynamixel.py index c6752ee96..bca455dc5 100644 --- a/src/lerobot/motors/dynamixel/dynamixel.py +++ b/src/lerobot/motors/dynamixel/dynamixel.py @@ -181,10 +181,10 @@ class DynamixelMotorsBus(SerialMotorsBus): for motor, m in self.motors.items(): calibration[motor] = MotorCalibration( id=m.id, - drive_mode=drive_modes[motor], - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + drive_mode=int(drive_modes[motor]), + homing_offset=int(offsets[motor]), + range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -198,7 +198,7 @@ class DynamixelMotorsBus(SerialMotorsBus): if cache: self.calibration = calibration_dict - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, 
TorqueMode.DISABLED.value, num_retry=num_retry) @@ -206,7 +206,7 @@ class DynamixelMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") self._write(addr, length, motor, TorqueMode.DISABLED.value, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) @@ -235,7 +235,7 @@ class DynamixelMotorsBus(SerialMotorsBus): On Dynamixel Motors: Present_Position = Actual_Position + Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -258,6 +258,6 @@ class DynamixelMotorsBus(SerialMotorsBus): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None return {id_: data[0] for id_, data in data_list.items()} diff --git a/src/lerobot/motors/feetech/feetech.py b/src/lerobot/motors/feetech/feetech.py index 7ce3388b6..58a65310d 100644 --- a/src/lerobot/motors/feetech/feetech.py +++ b/src/lerobot/motors/feetech/feetech.py @@ -126,7 +126,7 @@ class FeetechMotorsBus(SerialMotorsBus): self.port_handler = scs.PortHandler(self.port) # HACK: monkeypatch - self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( + self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( # type: ignore[method-assign] self.port_handler, scs.PortHandler ) self.packet_handler = scs.PacketHandler(protocol_version) @@ -262,9 +262,9 @@ class FeetechMotorsBus(SerialMotorsBus): calibration[motor] = MotorCalibration( id=m.id, drive_mode=0, - homing_offset=offsets[motor], - range_min=mins[motor], - range_max=maxes[motor], + homing_offset=int(offsets[motor]), + 
range_min=int(mins[motor]), + range_max=int(maxes[motor]), ) return calibration @@ -284,7 +284,7 @@ class FeetechMotorsBus(SerialMotorsBus): On Feetech Motors: Present_Position = Actual_Position - Homing_Offset """ - half_turn_homings = {} + half_turn_homings: dict[NameOrID, Value] = {} for motor, pos in positions.items(): model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 @@ -292,7 +292,7 @@ class FeetechMotorsBus(SerialMotorsBus): return half_turn_homings - def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) self.write("Lock", motor, 0, num_retry=num_retry) @@ -303,7 +303,7 @@ class FeetechMotorsBus(SerialMotorsBus): addr, length = get_address(self.model_ctrl_table, model, "Lock") self._write(addr, length, motor, 0, num_retry=num_retry) - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: for motor in self._get_motors_list(motors): self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) self.write("Lock", motor, 1, num_retry=num_retry) @@ -334,7 +334,7 @@ class FeetechMotorsBus(SerialMotorsBus): def _broadcast_ping(self) -> tuple[dict[int, int], int]: import scservo_sdk as scs - data_list = {} + data_list: dict[int, int] = {} status_length = 6 @@ -414,7 +414,7 @@ class FeetechMotorsBus(SerialMotorsBus): if not self._is_comm_success(comm): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) - return + return None ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} if ids_errors: diff --git a/src/lerobot/motors/motors_bus.py 
b/src/lerobot/motors/motors_bus.py index c04f718b6..bc3ffb7e2 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -23,6 +23,7 @@ from __future__ import annotations import abc import logging +from collections.abc import Sequence from contextlib import contextmanager from dataclasses import dataclass from enum import Enum @@ -93,7 +94,7 @@ class MotorsBusBase(abc.ABC): pass @abc.abstractmethod - def sync_write(self, data_name: str, values: Value | dict[str, Value]) -> None: + def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """Write values to multiple motors.""" pass @@ -179,15 +180,16 @@ class Motor: class PortHandler(Protocol): - def __init__(self, port_name): - self.is_open: bool - self.baudrate: int - self.packet_start_time: float - self.packet_timeout: float - self.tx_time_per_byte: float - self.is_using: bool - self.port_name: str - self.ser: serial.Serial + is_open: bool + baudrate: int + packet_start_time: float + packet_timeout: float + tx_time_per_byte: float + is_using: bool + port_name: str + ser: serial.Serial + + def __init__(self, port_name: str) -> None: ... def openPort(self): ... def closePort(self): ... @@ -240,19 +242,22 @@ class PacketHandler(Protocol): def regWriteTxRx(self, port, id, address, length, data): ... def syncReadTx(self, port, start_address, data_length, param, param_length): ... def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... + def broadcastPing(self, port): ... 
class GroupSyncRead(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.last_result: bool - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + last_result: bool + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id): ... def removeParam(self, id): ... @@ -265,15 +270,17 @@ class GroupSyncRead(Protocol): class GroupSyncWrite(Protocol): - def __init__(self, port, ph, start_address, data_length): - self.port: str - self.ph: PortHandler - self.start_address: int - self.data_length: int - self.is_param_changed: bool - self.param: list - self.data_dict: dict + port: str + ph: PortHandler + start_address: int + data_length: int + is_param_changed: bool + param: list + data_dict: dict + def __init__( + self, port: PortHandler, ph: PacketHandler, start_address: int, data_length: int + ) -> None: ... def makeParam(self): ... def addParam(self, id, data): ... def removeParam(self, id): ... 
@@ -400,7 +407,7 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motor_model(self, motor: NameOrID) -> int: + def _get_motor_model(self, motor: NameOrID) -> str: if isinstance(motor, str): return self.motors[motor].model elif isinstance(motor, int): @@ -408,17 +415,19 @@ class SerialMotorsBus(MotorsBusBase): else: raise TypeError(f"'{motor}' should be int, str.") - def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + def _get_motors_list(self, motors: NameOrID | Sequence[NameOrID] | None) -> list[str]: if motors is None: return list(self.motors) elif isinstance(motors, str): return [motors] - elif isinstance(motors, list): - return motors.copy() + elif isinstance(motors, int): + return [self._id_to_name(motors)] + elif isinstance(motors, Sequence): + return [m if isinstance(m, str) else self._id_to_name(m) for m in motors] else: raise TypeError(motors) - def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: + def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> dict[int, Value]: if isinstance(values, (int | float)): return dict.fromkeys(self.ids, values) elif isinstance(values, dict): @@ -640,18 +649,19 @@ class SerialMotorsBus(MotorsBusBase): pass @abc.abstractmethod - def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + def enable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: """Enable torque on selected motors. Args: - motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`. + motors (int | str | list[str] | None, optional): Same semantics as :pymeth:`disable_torque`. + Defaults to `None`. num_retry (int, optional): Number of additional retry attempts on communication failure. Defaults to 0. 
""" pass @contextmanager - def torque_disabled(self, motors: int | str | list[str] | None = None): + def torque_disabled(self, motors: str | list[str] | None = None): """Context-manager that guarantees torque is re-enabled. This helper is useful to temporarily disable torque when configuring motors. @@ -728,24 +738,19 @@ class SerialMotorsBus(MotorsBusBase): """ pass - def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None: + def reset_calibration(self, motors: NameOrID | Sequence[NameOrID] | None = None) -> None: """Restore factory calibration for the selected motors. Homing offset is set to ``0`` and min/max position limits are set to the full usable range. The in-memory :pyattr:`calibration` is cleared. Args: - motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default) + motors (NameOrID | Sequence[NameOrID] | None, optional): Selection of motors. `None` (default) resets every motor. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - for motor in motors: + for motor in motor_names: model = self._get_motor_model(motor) max_res = self.model_resolution_table[model] - 1 self.write("Homing_Offset", motor, 0, normalize=False) @@ -754,7 +759,9 @@ class SerialMotorsBus(MotorsBusBase): self.calibration = {} - def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]: + def set_half_turn_homings( + self, motors: NameOrID | Sequence[NameOrID] | None = None + ) -> dict[NameOrID, Value]: """Centre each motor range around its current position. The function computes and writes a homing offset such that the present position becomes exactly one @@ -764,17 +771,12 @@ class SerialMotorsBus(MotorsBusBase): motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. 
Defaults to all motors (`None`). Returns: - dict[NameOrID, Value]: Mapping *motor → written homing offset*. + dict[str, Value]: Mapping *motor name → written homing offset*. """ - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - self.reset_calibration(motors) - actual_positions = self.sync_read("Present_Position", motors, normalize=False) + self.reset_calibration(motor_names) + actual_positions = self.sync_read("Present_Position", motor_names, normalize=False) homing_offsets = self._get_half_turn_homings(actual_positions) for motor, offset in homing_offsets.items(): self.write("Homing_Offset", motor, offset) @@ -786,8 +788,8 @@ class SerialMotorsBus(MotorsBusBase): pass def record_ranges_of_motion( - self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True - ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + self, motors: NameOrID | Sequence[NameOrID] | None = None, display_values: bool = True + ) -> tuple[dict[str, Value], dict[str, Value]]: """Interactively record the min/max encoder values of each motor. Move the joints by hand (with torque disabled) while the method streams live positions. Press @@ -799,30 +801,25 @@ class SerialMotorsBus(MotorsBusBase): display_values (bool, optional): When `True` (default) a live table is printed to the console. Returns: - tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the + tuple[dict[str, Value], dict[str, Value]]: Two dictionaries *mins* and *maxes* with the extreme values observed for each motor. 
""" - if motors is None: - motors = list(self.motors) - elif isinstance(motors, (str | int)): - motors = [motors] - elif not isinstance(motors, list): - raise TypeError(motors) + motor_names = self._get_motors_list(motors) - start_positions = self.sync_read("Present_Position", motors, normalize=False) + start_positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = start_positions.copy() maxes = start_positions.copy() user_pressed_enter = False while not user_pressed_enter: - positions = self.sync_read("Present_Position", motors, normalize=False) + positions = self.sync_read("Present_Position", motor_names, normalize=False) mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()} maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()} if display_values: print("\n-------------------------------------------") print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") - for motor in motors: + for motor in motor_names: print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}") if enter_pressed(): @@ -830,9 +827,9 @@ class SerialMotorsBus(MotorsBusBase): if display_values and not user_pressed_enter: # Move cursor up to overwrite the previous output - move_cursor_up(len(motors) + 3) + move_cursor_up(len(motor_names) + 3) - same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]] + same_min_max = [motor for motor in motor_names if mins[motor] == maxes[motor]] if same_min_max: raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}") @@ -955,12 +952,12 @@ class SerialMotorsBus(MotorsBusBase): if raise_on_error: raise ConnectionError(self.packet_handler.getTxRxResult(comm)) else: - return + return None if self._is_error(error): if raise_on_error: raise RuntimeError(self.packet_handler.getRxPacketError(error)) else: - return + return None return model_number @@ -1007,12 +1004,13 @@ class SerialMotorsBus(MotorsBusBase): err_msg = 
f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) - id_value = self._decode_sign(data_name, {id_: value}) + decoded = self._decode_sign(data_name, {id_: value}) if normalize and data_name in self.normalized_data: - id_value = self._normalize(id_value) + normalized = self._normalize(decoded) + return normalized[id_] - return id_value[id_] + return decoded[id_] def _read( self, @@ -1023,7 +1021,7 @@ class SerialMotorsBus(MotorsBusBase): num_retry: int = 0, raise_on_error: bool = True, err_msg: str = "", - ) -> tuple[int, int]: + ) -> tuple[int, int, int]: if length == 1: read_fn = self.packet_handler.read1ByteTxRx elif length == 2: @@ -1073,13 +1071,14 @@ class SerialMotorsBus(MotorsBusBase): model = self.motors[motor].model addr, length = get_address(self.model_ctrl_table, model, data_name) + int_value = int(value) if normalize and data_name in self.normalized_data: - value = self._unnormalize({id_: value})[id_] + int_value = self._unnormalize({id_: value})[id_] - value = self._encode_sign(data_name, {id_: value})[id_] + int_value = self._encode_sign(data_name, {id_: int_value})[id_] - err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." - self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to write '{data_name}' on {id_=} with '{int_value}' after {num_retry + 1} tries." + self._write(addr, length, id_, int_value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) def _write( self, @@ -1113,7 +1112,7 @@ class SerialMotorsBus(MotorsBusBase): def sync_read( self, data_name: str, - motors: str | list[str] | None = None, + motors: NameOrID | Sequence[NameOrID] | None = None, *, normalize: bool = True, num_retry: int = 0, @@ -1122,7 +1121,7 @@ class SerialMotorsBus(MotorsBusBase): Args: data_name (str): Register name. 
- motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor. + motors (NameOrID | Sequence[NameOrID] | None, optional): Motors to query. `None` (default) reads every motor. normalize (bool, optional): Normalisation flag. Defaults to `True`. num_retry (int, optional): Retry attempts. Defaults to `0`. @@ -1143,16 +1142,17 @@ class SerialMotorsBus(MotorsBusBase): addr, length = get_address(self.model_ctrl_table, model, data_name) err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." - ids_values, _ = self._sync_read( + raw_ids_values, _ = self._sync_read( addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg ) - ids_values = self._decode_sign(data_name, ids_values) + decoded = self._decode_sign(data_name, raw_ids_values) if normalize and data_name in self.normalized_data: - ids_values = self._normalize(ids_values) + normalized = self._normalize(decoded) + return {self._id_to_name(id_): value for id_, value in normalized.items()} - return {self._id_to_name(id_): value for id_, value in ids_values.items()} + return {self._id_to_name(id_): value for id_, value in decoded.items()} def _sync_read( self, @@ -1224,21 +1224,24 @@ class SerialMotorsBus(MotorsBusBase): num_retry (int, optional): Retry attempts. Defaults to `0`. 
""" - ids_values = self._get_ids_values_dict(values) - models = [self._id_to_model(id_) for id_ in ids_values] + raw_ids_values = self._get_ids_values_dict(values) + models = [self._id_to_model(id_) for id_ in raw_ids_values] if self._has_different_ctrl_tables: assert_same_address(self.model_ctrl_table, models, data_name) model = next(iter(models)) addr, length = get_address(self.model_ctrl_table, model, data_name) + int_ids_values = {id_: int(val) for id_, val in raw_ids_values.items()} if normalize and data_name in self.normalized_data: - ids_values = self._unnormalize(ids_values) + int_ids_values = self._unnormalize(raw_ids_values) - ids_values = self._encode_sign(data_name, ids_values) + int_ids_values = self._encode_sign(data_name, int_ids_values) - err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." - self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + err_msg = f"Failed to sync write '{data_name}' with ids_values={int_ids_values} after {num_retry + 1} tries." 
+ self._sync_write( + addr, length, int_ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg + ) def _sync_write( self, From 1ba3975020c8079630ff7dda8fe983ad473d7c12 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 17:49:30 +0100 Subject: [PATCH 088/111] chore: use is_connected decorators (#2948) * chore: use is_connected decorators * chore(robots): add is_connected to bi setups too --- src/lerobot/cameras/opencv/camera_opencv.py | 19 ++++++--------- .../cameras/reachy2_camera/reachy2_camera.py | 14 ++++------- .../cameras/realsense/camera_realsense.py | 23 +++++++------------ src/lerobot/cameras/zmq/camera_zmq.py | 16 +++++-------- src/lerobot/motors/damiao/damiao.py | 16 ++++--------- .../bi_openarm_follower.py | 5 ++++ .../robots/bi_so_follower/bi_so_follower.py | 5 ++++ .../openarm_follower/openarm_follower.py | 15 ++++-------- .../bi_openarm_leader/bi_openarm_leader.py | 4 ++++ .../bi_so_leader/bi_so_leader.py | 4 +++- .../openarm_leader/openarm_leader.py | 11 ++++----- 11 files changed, 57 insertions(+), 75 deletions(-) diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index d581e1425..465ba7a1b 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -32,7 +32,8 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" os.environ["OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS"] = "0" import cv2 # type: ignore # TODO: add type stubs for OpenCV -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..utils import get_cv2_backend, get_cv2_rotation @@ -132,6 +133,7 @@ class OpenCVCamera(Camera): """Checks if the camera is currently connected and opened.""" return isinstance(self.videocapture, 
cv2.VideoCapture) and self.videocapture.isOpened() + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """ Connects to the OpenCV camera specified in the configuration. @@ -148,8 +150,6 @@ class OpenCVCamera(Camera): ConnectionError: If the specified camera index/path is not found or fails to open. RuntimeError: If the camera opens but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") # Use 1 thread for OpenCV operations to avoid potential conflicts or # blocking in multi-threaded applications, especially during data collection. @@ -178,6 +178,7 @@ class OpenCVCamera(Camera): logger.info(f"{self} connected.") + @check_if_not_connected def _configure_capture_settings(self) -> None: """ Applies the specified FOURCC, FPS, width, and height settings to the connected camera. @@ -197,8 +198,6 @@ class OpenCVCamera(Camera): to the requested value. DeviceNotConnectedError: If the camera is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") # Set FOURCC first (if specified) as it can affect available FPS/resolution options if self.config.fourcc is not None: @@ -348,6 +347,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -374,9 +374,6 @@ class OpenCVCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -490,6 +487,7 @@ class OpenCVCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -512,8 +510,6 @@ class OpenCVCamera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -533,6 +529,7 @@ class OpenCVCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -548,8 +545,6 @@ class OpenCVCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 5cede466d..0c1dc43d8 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -32,6 +32,7 @@ if platform.system() == "Windows" and "OPENCV_VIDEOIO_MSMF_ENABLE_HW_TRANSFORMS" import cv2 # type: ignore # TODO: add type stubs for OpenCV import numpy as np # type: ignore # TODO: add type stubs for numpy +from lerobot.utils.decorators import check_if_not_connected from lerobot.utils.import_utils import _reachy2_sdk_available if TYPE_CHECKING or _reachy2_sdk_available: @@ -123,6 +124,7 @@ class Reachy2Camera(Camera): """ raise NotImplementedError("Camera detection is not implemented for Reachy2 cameras.") + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -136,9 +138,6 @@ class Reachy2Camera(Camera): """ start_time = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.cam_manager is None: raise DeviceNotConnectedError(f"{self} is not connected.") @@ -184,6 +183,7 @@ class Reachy2Camera(Camera): return frame + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Same as read() @@ -197,11 +197,10 @@ class Reachy2Camera(Camera): TimeoutError: If no frame becomes available within the specified timeout. RuntimeError: If an unexpected error occurs. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") return self.read() + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -219,8 +218,6 @@ class Reachy2Camera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.latest_frame is None or self.latest_timestamp is None: raise RuntimeError(f"{self} has not captured any frames yet.") @@ -233,6 +230,7 @@ class Reachy2Camera(Camera): return self.latest_frame + @check_if_not_connected def disconnect(self) -> None: """ Stops the background read thread (if running). @@ -240,8 +238,6 @@ class Reachy2Camera(Camera): Raises: DeviceNotConnectedError: If the camera is already disconnected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} not connected.") if self.cam_manager is not None: self.cam_manager.disconnect() diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index e47f25381..d599cdce0 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -30,7 +30,8 @@ try: except Exception as e: logging.info(f"Could not import realsense: {e}") -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -152,6 +153,7 @@ class RealSenseCamera(Camera): """Checks if the camera pipeline is started and streams are active.""" return self.rs_pipeline is not None and self.rs_profile is not None + @check_if_already_connected def 
connect(self, warmup: bool = True) -> None: """ Connects to the RealSense camera specified in the configuration. @@ -169,8 +171,6 @@ class RealSenseCamera(Camera): ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all. RuntimeError: If the pipeline starts but fails to apply requested settings. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") self.rs_pipeline = rs.pipeline() rs_config = rs.config() @@ -290,6 +290,7 @@ class RealSenseCamera(Camera): if self.use_depth: rs_config.enable_stream(rs.stream.depth) + @check_if_not_connected def _configure_capture_settings(self) -> None: """Sets fps, width, and height from device stream if not already configured. @@ -299,8 +300,6 @@ class RealSenseCamera(Camera): Raises: DeviceNotConnectedError: If device is not connected. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.") if self.rs_profile is None: raise RuntimeError(f"{self}: rs_profile must be initialized before use.") @@ -320,6 +319,7 @@ class RealSenseCamera(Camera): self.width, self.height = actual_width, actual_height self.capture_width, self.capture_height = actual_width, actual_height + @check_if_not_connected def read_depth(self, timeout_ms: int = 200) -> NDArray[Any]: """ Reads a single frame (depth) synchronously from the camera. @@ -345,9 +345,6 @@ class RealSenseCamera(Camera): f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." 
) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -374,6 +371,7 @@ class RealSenseCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 0) -> NDArray[Any]: """ Reads a single frame (color) synchronously from the camera. @@ -403,9 +401,6 @@ class RealSenseCamera(Camera): f"{self} read() timeout_ms parameter is deprecated and will be removed in future versions." ) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -534,6 +529,7 @@ class RealSenseCamera(Camera): self.new_frame_event.clear() # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame data (color) asynchronously. @@ -556,8 +552,6 @@ class RealSenseCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread died unexpectedly or another error occurs. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -578,6 +572,7 @@ class RealSenseCamera(Camera): return frame # NOTE(Steven): Missing implementation for depth for now + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent (color) frame captured immediately (Peeking). @@ -593,8 +588,6 @@ class RealSenseCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/cameras/zmq/camera_zmq.py b/src/lerobot/cameras/zmq/camera_zmq.py index f29e16a28..16523b50a 100644 --- a/src/lerobot/cameras/zmq/camera_zmq.py +++ b/src/lerobot/cameras/zmq/camera_zmq.py @@ -34,7 +34,8 @@ import cv2 import numpy as np from numpy.typing import NDArray -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera from ..configs import ColorMode @@ -104,6 +105,7 @@ class ZMQCamera(Camera): """Checks if the ZMQ socket is initialized and connected.""" return self._connected and self.context is not None and self.socket is not None + @check_if_already_connected def connect(self, warmup: bool = True) -> None: """Connect to ZMQ camera server. @@ -111,8 +113,6 @@ class ZMQCamera(Camera): warmup (bool): If True, waits for the camera to provide at least one valid frame before returning. Defaults to True. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} is already connected.") logger.info(f"Connecting to {self}...") @@ -211,6 +211,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read(self, color_mode: ColorMode | None = None) -> NDArray[Any]: """ Reads a single frame synchronously from the camera. @@ -228,9 +229,6 @@ class ZMQCamera(Camera): f"{self} read() color_mode parameter is deprecated and will be removed in future versions." 
) - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -301,6 +299,7 @@ class ZMQCamera(Camera): self.latest_timestamp = None self.new_frame_event.clear() + @check_if_not_connected def async_read(self, timeout_ms: float = 200) -> NDArray[Any]: """ Reads the latest available frame asynchronously. @@ -317,8 +316,6 @@ class ZMQCamera(Camera): TimeoutError: If no frame data becomes available within the specified timeout. RuntimeError: If the background thread is not running. """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") @@ -335,6 +332,7 @@ class ZMQCamera(Camera): return frame + @check_if_not_connected def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). @@ -350,8 +348,6 @@ class ZMQCamera(Camera): DeviceNotConnectedError: If the camera is not connected. RuntimeError: If the camera is connected but has not captured any frames yet. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if self.thread is None or not self.thread.is_alive(): raise RuntimeError(f"{self} read thread is not running.") diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index 95a9e70d1..a454130a6 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -23,6 +23,7 @@ from copy import deepcopy from functools import cached_property from typing import TYPE_CHECKING, Any, TypedDict +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from lerobot.utils.import_utils import _can_available if TYPE_CHECKING or _can_available: @@ -36,7 +37,6 @@ else: import numpy as np -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import enter_pressed, move_cursor_up @@ -155,6 +155,7 @@ class DamiaoMotorsBus(MotorsBusBase): """Check if the CAN bus is connected.""" return self._is_connected and self.canbus is not None + @check_if_already_connected def connect(self, handshake: bool = True) -> None: """ Open the CAN bus and initialize communication. @@ -162,10 +163,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: handshake: If True, ping all motors to verify they're present """ - if self.is_connected: - raise DeviceAlreadyConnectedError( - f"{self.__class__.__name__}('{self.port}') is already connected." - ) try: # Auto-detect interface type based on port name @@ -249,6 +246,7 @@ class DamiaoMotorsBus(MotorsBusBase): ) logger.info("Handshake successful. All motors ready.") + @check_if_not_connected def disconnect(self, disable_torque: bool = True) -> None: """ Close the CAN bus connection. 
@@ -256,8 +254,6 @@ class DamiaoMotorsBus(MotorsBusBase): Args: disable_torque: If True, disable torque on all motors before disconnecting """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") if disable_torque: try: @@ -586,10 +582,9 @@ class DamiaoMotorsBus(MotorsBusBase): except Exception as e: logger.warning(f"Failed to decode response from {motor}: {e}") + @check_if_not_connected def read(self, data_name: str, motor: str) -> Value: """Read a value from a single motor. Positions are always in degrees.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Refresh motor to get latest state msg = self._refresh_motor(motor) @@ -619,6 +614,7 @@ class DamiaoMotorsBus(MotorsBusBase): raise ValueError(f"Unknown data_name: {data_name}") return mapping[data_name] + @check_if_not_connected def write( self, data_name: str, @@ -629,8 +625,6 @@ class DamiaoMotorsBus(MotorsBusBase): Write a value to a single motor. Positions are always in degrees. Can write 'Goal_Position', 'Kp', or 'Kd'. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") if data_name in ("Kp", "Kd"): self._gains[motor][data_name.lower()] = float(value) diff --git a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py index 466eb07e5..2e3885e67 100644 --- a/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py +++ b/src/lerobot/robots/bi_openarm_follower/bi_openarm_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.openarm_follower import OpenArmFollower, OpenArmFollowerConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_openarm_follower import BiOpenArmFollowerConfig @@ -112,6 +113,7 @@ class BiOpenArmFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -133,6 +135,7 @@ class BiOpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." 
) + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -146,6 +149,7 @@ class BiOpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -170,6 +174,7 @@ class BiOpenArmFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/robots/bi_so_follower/bi_so_follower.py b/src/lerobot/robots/bi_so_follower/bi_so_follower.py index 09f849772..28c58b898 100644 --- a/src/lerobot/robots/bi_so_follower/bi_so_follower.py +++ b/src/lerobot/robots/bi_so_follower/bi_so_follower.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction, RobotObservation from lerobot.robots.so_follower import SOFollower, SOFollowerRobotConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from .config_bi_so_follower import BiSOFollowerConfig @@ -96,6 +97,7 @@ class BiSOFollower(Robot): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -116,6 +118,7 @@ class BiSOFollower(Robot): self.left_arm.setup_motors() self.right_arm.setup_motors() + @check_if_not_connected def get_observation(self) -> RobotObservation: obs_dict = {} @@ -129,6 +132,7 @@ class BiSOFollower(Robot): return obs_dict + @check_if_not_connected def send_action(self, action: RobotAction) -> RobotAction: # Remove "left_" prefix left_action = { @@ -148,6 +152,7 @@ class BiSOFollower(Robot): return {**prefixed_sent_action_left, **prefixed_sent_action_right} + @check_if_not_connected def disconnect(self): self.left_arm.disconnect() self.right_arm.disconnect() diff --git 
a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index c221afd10..d6794a226 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -23,7 +23,7 @@ from lerobot.cameras.utils import make_cameras_from_configs from lerobot.motors import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction, RobotObservation -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..robot import Robot from ..utils import ensure_safe_goal_position @@ -119,6 +119,7 @@ class OpenArmFollower(Robot): """Check if robot is connected.""" return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the robot and optionally calibrate. @@ -126,8 +127,6 @@ class OpenArmFollower(Robot): We assume that at connection time, the arms are in a safe rest position, and torque can be safely disabled to run calibration if needed. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -219,6 +218,7 @@ class OpenArmFollower(Robot): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_observation(self) -> RobotObservation: """ Get current observation from robot including position, velocity, and torque. 
@@ -228,9 +228,6 @@ class OpenArmFollower(Robot): """ start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - obs_dict: dict[str, Any] = {} states = self.bus.sync_read_all_states() @@ -253,6 +250,7 @@ class OpenArmFollower(Robot): return obs_dict + @check_if_not_connected def send_action( self, action: RobotAction, @@ -272,8 +270,6 @@ class OpenArmFollower(Robot): Returns: The action actually sent (potentially clipped) """ - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} @@ -333,10 +329,9 @@ class OpenArmFollower(Robot): return {f"{motor}.pos": val for motor, val in goal_pos.items()} + @check_if_not_connected def disconnect(self): """Disconnect from robot.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus self.bus.disconnect(self.config.disable_torque_on_disconnect) diff --git a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py index c4383293f..74b0c9b83 100644 --- a/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py +++ b/src/lerobot/teleoperators/bi_openarm_leader/bi_openarm_leader.py @@ -19,6 +19,7 @@ from functools import cached_property from lerobot.processor import RobotAction from lerobot.teleoperators.openarm_leader import OpenArmLeaderConfig +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..openarm_leader import OpenArmLeader from ..teleoperator import Teleoperator @@ -88,6 +89,7 @@ class BiOpenArmLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) 
self.right_arm.connect(calibrate) @@ -109,6 +111,7 @@ class BiOpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: action_dict = {} @@ -126,6 +129,7 @@ class BiOpenArmLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py index 90bf2a92d..e84ac6f50 100644 --- a/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py +++ b/src/lerobot/teleoperators/bi_so_leader/bi_so_leader.py @@ -18,7 +18,7 @@ import logging from functools import cached_property from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig -from lerobot.utils.decorators import check_if_not_connected +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..so_leader import SOLeader from ..teleoperator import Teleoperator @@ -72,6 +72,7 @@ class BiSOLeader(Teleoperator): def is_connected(self) -> bool: return self.left_arm.is_connected and self.right_arm.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: self.left_arm.connect(calibrate) self.right_arm.connect(calibrate) @@ -110,6 +111,7 @@ class BiSOLeader(Teleoperator): # TODO: Implement force feedback raise NotImplementedError + @check_if_not_connected def disconnect(self) -> None: self.left_arm.disconnect() self.right_arm.disconnect() diff --git a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py index edf4d7090..d9eaabe0f 100644 --- a/src/lerobot/teleoperators/openarm_leader/openarm_leader.py +++ b/src/lerobot/teleoperators/openarm_leader/openarm_leader.py @@ -21,7 +21,7 @@ from typing import Any from lerobot.motors 
import Motor, MotorCalibration, MotorNormMode from lerobot.motors.damiao import DamiaoMotorsBus from lerobot.processor import RobotAction -from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected from ..teleoperator import Teleoperator from .config_openarm_leader import OpenArmLeaderConfig @@ -84,6 +84,7 @@ class OpenArmLeader(Teleoperator): """Check if teleoperator is connected.""" return self.bus.is_connected + @check_if_already_connected def connect(self, calibrate: bool = True) -> None: """ Connect to the teleoperator. @@ -91,8 +92,6 @@ class OpenArmLeader(Teleoperator): For manual control, we disable torque after connecting so the arm can be moved by hand. """ - if self.is_connected: - raise DeviceAlreadyConnectedError(f"{self} already connected") # Connect to CAN bus logger.info(f"Connecting arm on {self.config.port}...") @@ -183,6 +182,7 @@ class OpenArmLeader(Teleoperator): "Motor ID configuration is typically done via manufacturer tools for CAN motors." ) + @check_if_not_connected def get_action(self) -> RobotAction: """ Get current action from the leader arm. @@ -193,8 +193,6 @@ class OpenArmLeader(Teleoperator): Reads all motor states (pos/vel/torque) in one CAN refresh cycle. 
""" start = time.perf_counter() - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") action_dict: dict[str, Any] = {} @@ -214,10 +212,9 @@ class OpenArmLeader(Teleoperator): def send_feedback(self, feedback: dict[str, float]) -> None: raise NotImplementedError("Feedback is not yet implemented for OpenArm leader.") + @check_if_not_connected def disconnect(self) -> None: """Disconnect from teleoperator.""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") # Disconnect CAN bus # For manual control, ensure torque is disabled before disconnecting From 3c84d271d53c9ca972cda8fce3b3f715ec813817 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Tue, 10 Feb 2026 18:40:50 +0100 Subject: [PATCH 089/111] fix(motors): use decorator to fix precommit (#2951) --- src/lerobot/motors/damiao/damiao.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/lerobot/motors/damiao/damiao.py b/src/lerobot/motors/damiao/damiao.py index a454130a6..ae619f159 100644 --- a/src/lerobot/motors/damiao/damiao.py +++ b/src/lerobot/motors/damiao/damiao.py @@ -700,14 +700,12 @@ class DamiaoMotorsBus(MotorsBusBase): else: logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + @check_if_not_connected def sync_write(self, data_name: str, values: dict[str, Value]) -> None: """ Write values to multiple motors simultaneously. Positions are always in degrees. 
""" - if not self.is_connected: - raise DeviceNotConnectedError(f"{self} is not connected.") - if data_name in ("Kp", "Kd"): key = data_name.lower() for motor, val in values.items(): From fc8a388a2538937992bd8b28bc7ac909ebd1b9a0 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 11 Feb 2026 13:57:25 +0100 Subject: [PATCH 090/111] feat(cameras): make backend configurable to the CLI (#2945) * feat(cameras): make backend configurable to the CLI * chore(cameras): address feedback * feat(Enum error messages): adding better instanciation error messages for Enum classes * chore(Enum error messages): propagating Enum error messages to all camera classes * chore(comments): removing superfluous comments * chore(format): applying ruff checks --------- Co-authored-by: CarolinePascal --- src/lerobot/cameras/__init__.py | 2 +- src/lerobot/cameras/configs.py | 23 +++++++++++++++++++ src/lerobot/cameras/opencv/camera_opencv.py | 4 ++-- .../cameras/opencv/configuration_opencv.py | 23 ++++++------------- .../configuration_reachy2_camera.py | 5 +--- .../realsense/configuration_realsense.py | 16 ++----------- src/lerobot/cameras/utils.py | 12 ---------- src/lerobot/cameras/zmq/configuration_zmq.py | 5 +--- 8 files changed, 37 insertions(+), 53 deletions(-) diff --git a/src/lerobot/cameras/__init__.py b/src/lerobot/cameras/__init__.py index 1488cd89e..cbf1f11bf 100644 --- a/src/lerobot/cameras/__init__.py +++ b/src/lerobot/cameras/__init__.py @@ -13,5 +13,5 @@ # limitations under the License. 
from .camera import Camera -from .configs import CameraConfig, ColorMode, Cv2Rotation +from .configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation from .utils import make_cameras_from_configs diff --git a/src/lerobot/cameras/configs.py b/src/lerobot/cameras/configs.py index 056eec314..987b74775 100644 --- a/src/lerobot/cameras/configs.py +++ b/src/lerobot/cameras/configs.py @@ -25,6 +25,10 @@ class ColorMode(str, Enum): RGB = "rgb" BGR = "bgr" + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`color_mode` is expected to be in {list(cls)}, but {value} is provided.") + class Cv2Rotation(int, Enum): NO_ROTATION = 0 @@ -32,6 +36,25 @@ class Cv2Rotation(int, Enum): ROTATE_180 = 180 ROTATE_270 = -90 + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`rotation` is expected to be in {list(cls)}, but {value} is provided.") + + +# Subset from https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html +class Cv2Backends(int, Enum): + ANY = 0 + V4L2 = 200 + DSHOW = 700 + PVAPI = 800 + ANDROID = 1000 + AVFOUNDATION = 1200 + MSMF = 1400 + + @classmethod + def _missing_(cls, value: object) -> None: + raise ValueError(f"`backend` is expected to be in {list(cls)}, but {value} is provided.") + @dataclass(kw_only=True) class CameraConfig(draccus.ChoiceRegistry, abc.ABC): # type: ignore # TODO: add type stubs for draccus diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 465ba7a1b..10b3f21da 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -36,7 +36,7 @@ from lerobot.utils.decorators import check_if_already_connected, check_if_not_co from lerobot.utils.errors import DeviceNotConnectedError from ..camera import Camera -from ..utils import get_cv2_backend, get_cv2_rotation +from ..utils import get_cv2_rotation from .configuration_opencv import ColorMode, OpenCVCameraConfig # NOTE(Steven): 
The maximum opencv device index depends on your operating system. For instance, @@ -118,7 +118,7 @@ class OpenCVCamera(Camera): self.new_frame_event: Event = Event() self.rotation: int | None = get_cv2_rotation(config.rotation) - self.backend: int = get_cv2_backend() + self.backend: int = config.backend if self.height and self.width: self.capture_width, self.capture_height = self.width, self.height diff --git a/src/lerobot/cameras/opencv/configuration_opencv.py b/src/lerobot/cameras/opencv/configuration_opencv.py index 37a42861c..8ae57fe3c 100644 --- a/src/lerobot/cameras/opencv/configuration_opencv.py +++ b/src/lerobot/cameras/opencv/configuration_opencv.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from pathlib import Path -from ..configs import CameraConfig, ColorMode, Cv2Rotation +from ..configs import CameraConfig, ColorMode, Cv2Backends, Cv2Rotation -__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation"] +__all__ = ["OpenCVCameraConfig", "ColorMode", "Cv2Rotation", "Cv2Backends"] @CameraConfig.register_subclass("opencv") @@ -50,6 +50,7 @@ class OpenCVCameraConfig(CameraConfig): rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation. warmup_s: Time reading frames before returning from connect (in seconds) fourcc: FOURCC code for video format (e.g., "MJPG", "YUYV", "I420"). Defaults to None (auto-detect). + backend: OpenCV backend identifier (https://docs.opencv.org/3.4/d4/d15/group__videoio__flags__base.html). Defaults to ANY. Note: - Only 3-channel color output (RGB/BGR) is currently supported. @@ -62,22 +63,12 @@ class OpenCVCameraConfig(CameraConfig): rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION warmup_s: int = 1 fourcc: str | None = None + backend: Cv2Backends = Cv2Backends.ANY def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." 
- ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) + self.backend = Cv2Backends(self.backend) if self.fourcc is not None and (not isinstance(self.fourcc, str) or len(self.fourcc) != 4): raise ValueError( diff --git a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py index ca6db4f03..b40bfe71b 100644 --- a/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/configuration_reachy2_camera.py @@ -74,7 +74,4 @@ class Reachy2CameraConfig(CameraConfig): f"`image_type` is expected to be 'left' or 'right' for teleop camera, and 'rgb' or 'depth' for depth camera, but {self.image_type} is provided." ) - if self.color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." - ) + self.color_mode = ColorMode(self.color_mode) diff --git a/src/lerobot/cameras/realsense/configuration_realsense.py b/src/lerobot/cameras/realsense/configuration_realsense.py index a094128bc..71b083b00 100644 --- a/src/lerobot/cameras/realsense/configuration_realsense.py +++ b/src/lerobot/cameras/realsense/configuration_realsense.py @@ -60,20 +60,8 @@ class RealSenseCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." 
- ) - - if self.rotation not in ( - Cv2Rotation.NO_ROTATION, - Cv2Rotation.ROTATE_90, - Cv2Rotation.ROTATE_180, - Cv2Rotation.ROTATE_270, - ): - raise ValueError( - f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." - ) + self.color_mode = ColorMode(self.color_mode) + self.rotation = Cv2Rotation(self.rotation) values = (self.fps, self.width, self.height) if any(v is not None for v in values) and any(v is None for v in values): diff --git a/src/lerobot/cameras/utils.py b/src/lerobot/cameras/utils.py index c0e7b6284..7fb2c3bb1 100644 --- a/src/lerobot/cameras/utils.py +++ b/src/lerobot/cameras/utils.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import platform from typing import cast from lerobot.utils.import_utils import make_device_from_device_class @@ -68,14 +67,3 @@ def get_cv2_rotation(rotation: Cv2Rotation) -> int | None: return int(cv2.ROTATE_90_COUNTERCLOCKWISE) else: return None - - -def get_cv2_backend() -> int: - import cv2 - - if platform.system() == "Windows": - return int(cv2.CAP_MSMF) # Use MSMF for Windows instead of AVFOUNDATION - # elif platform.system() == "Darwin": # macOS - # return cv2.CAP_AVFOUNDATION - else: # Linux and others - return int(cv2.CAP_ANY) diff --git a/src/lerobot/cameras/zmq/configuration_zmq.py b/src/lerobot/cameras/zmq/configuration_zmq.py index 4e7732cfc..13690e14c 100644 --- a/src/lerobot/cameras/zmq/configuration_zmq.py +++ b/src/lerobot/cameras/zmq/configuration_zmq.py @@ -32,10 +32,7 @@ class ZMQCameraConfig(CameraConfig): warmup_s: int = 1 def __post_init__(self) -> None: - if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): - raise ValueError( - f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." 
- ) + self.color_mode = ColorMode(self.color_mode) if self.timeout_ms <= 0: raise ValueError(f"`timeout_ms` must be positive, but {self.timeout_ms} is provided.") From 3615160d891f00a1cb8258ed8f81d327049b640e Mon Sep 17 00:00:00 2001 From: taken-yjyoon Date: Fri, 13 Feb 2026 02:13:51 +0900 Subject: [PATCH 091/111] fix(typo): Fixing wrong argparse examples in the comments (using 'True' not 'true') (#1040) Co-authored-by: juni <> --- src/lerobot/processor/pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 97ec716ff..8de376928 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -413,7 +413,7 @@ class DataProcessorPipeline(HubMixin, Generic[TInput, TOutput]): Args: save_directory: The directory where the pipeline will be saved. If None, saves to HF_LEROBOT_HOME/processors/{sanitized_pipeline_name}. - repo_id: ID of your repository on the Hub. Used only if `push_to_hub=True`. + repo_id: ID of your repository on the Hub. Used only if `push_to_hub=true`. push_to_hub: Whether or not to push your object to the Hugging Face Hub after saving it. card_kwargs: Additional arguments passed to the card template to customize the card. config_filename: The name of the JSON configuration file. 
If None, a name is From adebbcf090b47913d3f2e27bb27feccab174f2fc Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Thu, 12 Feb 2026 18:56:04 +0100 Subject: [PATCH 092/111] fix(dataset tools draccus): fixing draccus parsing for dataset edit operation type specification (#2949) * fix(edit dataset operation): fixing dataset tools CLI operation type specification * test(edit dataset operation): adding tests for dataset tools operation type specification * chore(format): running pre-commit * chore(backward compatibility): adding a type property in OperationConfig for backward compatibility Signed-off-by: Caroline Pascal --- src/lerobot/scripts/lerobot_edit_dataset.py | 49 +++++++------- tests/scripts/test_edit_dataset_parsing.py | 71 +++++++++++++++++++++ 2 files changed, 96 insertions(+), 24 deletions(-) create mode 100644 tests/scripts/test_edit_dataset_parsing.py diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 2ca9c520d..7c222ac6c 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -109,11 +109,14 @@ Using JSON config file: --config_path path/to/edit_config.json """ +import abc import logging import shutil from dataclasses import dataclass from pathlib import Path +import draccus + from lerobot.configs import parser from lerobot.datasets.dataset_tools import ( convert_image_to_video_dataset, @@ -129,39 +132,46 @@ from lerobot.utils.utils import init_logging @dataclass -class DeleteEpisodesConfig: - type: str = "delete_episodes" +class OperationConfig(draccus.ChoiceRegistry, abc.ABC): + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@OperationConfig.register_subclass("delete_episodes") +@dataclass +class DeleteEpisodesConfig(OperationConfig): episode_indices: list[int] | None = None +@OperationConfig.register_subclass("split") @dataclass -class SplitConfig: - type: str = "split" +class 
SplitConfig(OperationConfig): splits: dict[str, float | list[int]] | None = None +@OperationConfig.register_subclass("merge") @dataclass -class MergeConfig: - type: str = "merge" +class MergeConfig(OperationConfig): repo_ids: list[str] | None = None +@OperationConfig.register_subclass("remove_feature") @dataclass -class RemoveFeatureConfig: - type: str = "remove_feature" +class RemoveFeatureConfig(OperationConfig): feature_names: list[str] | None = None +@OperationConfig.register_subclass("modify_tasks") @dataclass -class ModifyTasksConfig: - type: str = "modify_tasks" +class ModifyTasksConfig(OperationConfig): new_task: str | None = None episode_tasks: dict[str, str] | None = None +@OperationConfig.register_subclass("convert_image_to_video") @dataclass -class ConvertImageToVideoConfig: - type: str = "convert_image_to_video" +class ConvertImageToVideoConfig(OperationConfig): output_dir: str | None = None vcodec: str = "libsvtav1" pix_fmt: str = "yuv420p" @@ -177,14 +187,7 @@ class ConvertImageToVideoConfig: @dataclass class EditDatasetConfig: repo_id: str - operation: ( - DeleteEpisodesConfig - | SplitConfig - | MergeConfig - | RemoveFeatureConfig - | ModifyTasksConfig - | ConvertImageToVideoConfig - ) + operation: OperationConfig root: str | None = None new_repo_id: str | None = None push_to_hub: bool = False @@ -450,10 +453,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) else: - raise ValueError( - f"Unknown operation type: {operation_type}\n" - f"Available operations: delete_episodes, split, merge, remove_feature, modify_tasks, convert_image_to_video" - ) + available = ", ".join(OperationConfig.get_known_choices()) + raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}") def main() -> None: diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py new file mode 100644 index 000000000..bf7386b52 
--- /dev/null +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import draccus +import pytest + +from lerobot.scripts.lerobot_edit_dataset import ( + ConvertImageToVideoConfig, + DeleteEpisodesConfig, + EditDatasetConfig, + MergeConfig, + ModifyTasksConfig, + OperationConfig, + RemoveFeatureConfig, + SplitConfig, +) + + +def parse_cfg(cli_args: list[str]) -> EditDatasetConfig: + """Helper to parse CLI args into an EditDatasetConfig via draccus.""" + return draccus.parse(EditDatasetConfig, args=cli_args) + + +class TestOperationTypeParsing: + """Test that --operation.type correctly selects the right config subclass.""" + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + ("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_operation_type_resolves_correct_class(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + assert isinstance(cfg.operation, expected_cls), ( + f"Expected {expected_cls.__name__}, got {type(cfg.operation).__name__}" + ) + + @pytest.mark.parametrize( + "type_name, expected_cls", + [ + ("delete_episodes", DeleteEpisodesConfig), + 
("split", SplitConfig), + ("merge", MergeConfig), + ("remove_feature", RemoveFeatureConfig), + ("modify_tasks", ModifyTasksConfig), + ("convert_image_to_video", ConvertImageToVideoConfig), + ], + ) + def test_get_choice_name_roundtrips(self, type_name, expected_cls): + cfg = parse_cfg(["--repo_id", "test/repo", "--operation.type", type_name]) + resolved_name = OperationConfig.get_choice_name(type(cfg.operation)) + assert resolved_name == type_name From 6600b60e7f5cc7476ddc34beaaf0e0692f82e4b6 Mon Sep 17 00:00:00 2001 From: Pepijn <138571049+pkooij@users.noreply.github.com> Date: Fri, 13 Feb 2026 13:49:01 +0100 Subject: [PATCH 093/111] always use degrees (#2968) --- src/lerobot/robots/so_follower/config_so_follower.py | 2 +- src/lerobot/teleoperators/so_leader/config_so_leader.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/robots/so_follower/config_so_follower.py b/src/lerobot/robots/so_follower/config_so_follower.py index e9ce27123..1ee589bda 100644 --- a/src/lerobot/robots/so_follower/config_so_follower.py +++ b/src/lerobot/robots/so_follower/config_so_follower.py @@ -40,7 +40,7 @@ class SOFollowerConfig: cameras: dict[str, CameraConfig] = field(default_factory=dict) # Set to `True` for backward compatibility with previous policies/dataset - use_degrees: bool = False + use_degrees: bool = True @RobotConfig.register_subclass("so101_follower") diff --git a/src/lerobot/teleoperators/so_leader/config_so_leader.py b/src/lerobot/teleoperators/so_leader/config_so_leader.py index dd55196d7..2b4f782a7 100644 --- a/src/lerobot/teleoperators/so_leader/config_so_leader.py +++ b/src/lerobot/teleoperators/so_leader/config_so_leader.py @@ -28,7 +28,7 @@ class SOLeaderConfig: port: str # Whether to use degrees for angles - use_degrees: bool = False + use_degrees: bool = True @TeleoperatorConfig.register_subclass("so101_leader") From 51d3822d75491507561b5f11db0e62d56b342d93 Mon Sep 17 00:00:00 2001 From: masato-ka Date: Wed, 18 Feb 2026 
04:09:42 +0900 Subject: [PATCH 094/111] feat(datasets): Add info operation to lerobot-edit-dataset command (#2917) * Add new feature to lerobot_edit_dataset.py that shows dataset information. * Fix to draccus error when happen give only --operation.type=info * Updating test and documents regarding lerobot-edit-dataset info function. * Updating documents regarding lerobot-edit-dataset extract function. option name in document is mistake. * feat(datasets): Update to align formatting with pre-commit.(#2917) Update to align formatting by pre-commit. --------- Co-authored-by: Caroline Pascal --- docs/source/using_dataset_tools.mdx | 25 ++++++++ src/lerobot/scripts/lerobot_edit_dataset.py | 65 +++++++++++++++++++++ tests/scripts/test_edit_dataset_parsing.py | 3 + 3 files changed, 93 insertions(+) diff --git a/docs/source/using_dataset_tools.mdx b/docs/source/using_dataset_tools.mdx index 9e662604e..f7fc9be20 100644 --- a/docs/source/using_dataset_tools.mdx +++ b/docs/source/using_dataset_tools.mdx @@ -12,6 +12,7 @@ LeRobot provides several utilities for manipulating datasets: 4. **Add Features** - Add new features to a dataset 5. **Remove Features** - Remove features from a dataset 6. **Convert to Video** - Convert image-based datasets to video format for efficient storage +7. **Show the Info of Datasets** - Show a summary of dataset information, such as the number of episodes. The core implementation is in `lerobot.datasets.dataset_tools`. An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`. @@ -156,6 +157,30 @@ lerobot-edit-dataset \ **Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved. +### Show the information of datasets + +Show information about a dataset, such as the number of episodes, the number of frames, the file size, and so on.
+No change will be made to the dataset + +```bash + +# Show dataset information without feature details +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + +# Show dataset information with feature details +lerobot-edit-dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features true + +``` + +**Parameters:** + +- `parameters`: The flag to control show or no show dataset information with feature details.(default=false) + ### Push to Hub Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub: diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 7c222ac6c..06e256fa2 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -104,6 +104,18 @@ Convert image dataset to video format and push to hub: --operation.type convert_image_to_video \ --push_to_hub true +Show dataset information: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features true + +Show dataset information without feature details: + python -m lerobot.scripts.lerobot_edit_dataset \ + --repo_id lerobot/pusht_image \ + --operation.type info \ + --operation.show_features false + Using JSON config file: python -m lerobot.scripts.lerobot_edit_dataset \ --config_path path/to/edit_config.json @@ -112,6 +124,7 @@ Using JSON config file: import abc import logging import shutil +import sys from dataclasses import dataclass from pathlib import Path @@ -184,6 +197,13 @@ class ConvertImageToVideoConfig(OperationConfig): max_frames_per_batch: int | None = None +@OperationConfig.register_subclass("info") +@dataclass +class InfoConfig(OperationConfig): + type: str = "info" + show_features: bool = False + + @dataclass class EditDatasetConfig: repo_id: str @@ -436,6 +456,49 @@ def 
handle_convert_image_to_video(cfg: EditDatasetConfig) -> None: logging.info("Dataset saved locally (not pushed to hub)") +def _get_dataset_size(repo_path): + import os + + total = 0 + with os.scandir(repo_path) as it: + for entry in it: + if entry.is_file(): + total += entry.stat().st_size + elif entry.is_dir(): + total += _get_dataset_size(entry.path) + return total + + +def handle_info(cfg: EditDatasetConfig): + if not isinstance(cfg.operation, InfoConfig): + raise ValueError("Operation config must be InfoConfig") + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.root) + sys.stdout.write(f"======Info {dataset.meta.repo_id}\n") + sys.stdout.write(f"Repository ID: {dataset.meta.repo_id} \n") + sys.stdout.write(f"Total episode: {dataset.meta.total_episodes} \n") + sys.stdout.write(f"Total task: {dataset.meta.total_tasks} \n") + sys.stdout.write(f"Total frame(Actual Count): {dataset.meta.total_frames}({len(dataset)}) \n") + sys.stdout.write( + f"Average frame per episode: {dataset.meta.total_frames / dataset.meta.total_episodes:.1f}\n" + ) + sys.stdout.write( + f"Average episode time(sec): {(dataset.meta.total_frames / dataset.meta.total_episodes) / dataset.meta.fps:.1f}\n" + ) + sys.stdout.write(f"FPS: {dataset.meta.fps}\n") + + total_file_size = _get_dataset_size(dataset.root) + sys.stdout.write(f"Size: {total_file_size / (1024 * 1024):.1f} MB\n") + if cfg.operation.show_features: + import json + + feature_dump_str = json.dumps( + dataset.meta.features, ensure_ascii=False, indent=4, sort_keys=True, separators=(",", ": ") + ) + sys.stdout.write("Features:\n") + sys.stdout.write(f"{feature_dump_str}\n") + + @parser.wrap() def edit_dataset(cfg: EditDatasetConfig) -> None: operation_type = cfg.operation.type @@ -452,6 +515,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None: handle_modify_tasks(cfg) elif operation_type == "convert_image_to_video": handle_convert_image_to_video(cfg) + elif operation_type == "info": + handle_info(cfg) else: available = ", 
".join(OperationConfig.get_known_choices()) raise ValueError(f"Unknown operation: {operation_type}\nAvailable operations: {available}") diff --git a/tests/scripts/test_edit_dataset_parsing.py b/tests/scripts/test_edit_dataset_parsing.py index bf7386b52..8800b92ee 100644 --- a/tests/scripts/test_edit_dataset_parsing.py +++ b/tests/scripts/test_edit_dataset_parsing.py @@ -21,6 +21,7 @@ from lerobot.scripts.lerobot_edit_dataset import ( ConvertImageToVideoConfig, DeleteEpisodesConfig, EditDatasetConfig, + InfoConfig, MergeConfig, ModifyTasksConfig, OperationConfig, @@ -46,6 +47,7 @@ class TestOperationTypeParsing: ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), + ("info", InfoConfig), ], ) def test_operation_type_resolves_correct_class(self, type_name, expected_cls): @@ -63,6 +65,7 @@ class TestOperationTypeParsing: ("remove_feature", RemoveFeatureConfig), ("modify_tasks", ModifyTasksConfig), ("convert_image_to_video", ConvertImageToVideoConfig), + ("info", InfoConfig), ], ) def test_get_choice_name_roundtrips(self, type_name, expected_cls): From 1c388c0002c609ca783bf42729a1e41532a1fba0 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 17 Feb 2026 23:37:46 +0100 Subject: [PATCH 095/111] (Chore) Bump upper bound for torch version (#2897) * Bump upper torch version bound * Apply suggestion from @Copilot Signed-off-by: Vladislav Sovrasov * Update ref state dicts for schedulers * Support older than 2.8 torch versions * Fix precommit --------- Signed-off-by: Vladislav Sovrasov --- pyproject.toml | 6 +++--- tests/optim/test_schedulers.py | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c4b1c547e..e5431ada3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,9 +76,9 @@ dependencies = [ "pyserial>=3.5,<4.0", "wandb>=0.24.0,<0.25.0", - "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency - 
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency - "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency + "torch>=2.2.1,<2.11.0", # TODO: Bump dependency + "torchcodec>=0.2.1,<0.11.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bump dependency + "torchvision>=0.21.0,<0.26.0", # TODO: Bump dependency "draccus==0.10.0", # TODO: Remove == "gymnasium>=1.1.1,<2.0.0", diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 1e566a6ba..224613416 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import torch +from packaging.version import Version from torch.optim.lr_scheduler import LambdaLR from lerobot.optim.schedulers import ( @@ -38,6 +40,10 @@ def test_diffuser_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -56,6 +62,10 @@ def test_vqbet_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict @@ -76,6 +86,10 @@ def test_cosine_decay_with_warmup_scheduler(optimizer): "last_epoch": 1, "lr_lambdas": [None], } + + if Version(torch.__version__) >= Version("2.8"): + expected_state_dict["_is_initial"] = False + assert scheduler.state_dict() == expected_state_dict From af036ce57e8ce2750f7fa57f4262c87a013bcdff Mon Sep 17 00:00:00 2001 From: Sota Nakamura <49087984+sotanakamura@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:05:51 +0900 Subject: [PATCH 096/111] fix(scripts): serve grpc for a web viewer (#2881) * serve grpc for a web viewer * add help * remove ip detection * fix comment * pass grpc_port * fix(CLI): fixing CLI display-compressed-images argument 1/2 Co-authored-by: HUANG TZU-CHUN Signed-off-by: Caroline Pascal * fix(CLI): fixing CLI display-compressed-images argument 2/2 Co-authored-by: HUANG TZU-CHUN Signed-off-by: Caroline Pascal --------- Signed-off-by: Caroline Pascal Co-authored-by: Caroline Pascal Co-authored-by: HUANG TZU-CHUN Co-authored-by: Steven Palma --- src/lerobot/scripts/lerobot_dataset_viz.py | 37 +++++++++++++++------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2cd48eab8..29d64554f 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ 
-47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd ``` - Visualize data stored on a distant machine through streaming: -(You need to forward the websocket port to the distant machine, with -`ssh -L 9087:localhost:9087 username@remote-host`) ``` distant$ lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 \ --mode distant \ - --ws-port 9087 + --grpc-port 9876 -local$ rerun ws://localhost:9087 +local$ rerun rerun+http://IP:GRPC_PORT/proxy ``` """ @@ -75,6 +73,7 @@ import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD +from lerobot.utils.utils import init_logging def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: @@ -93,10 +92,11 @@ def visualize_dataset( num_workers: int = 0, mode: str = "local", web_port: int = 9090, - ws_port: int = 9087, + grpc_port: int = 9876, save: bool = False, output_dir: Path | None = None, display_compressed_images: bool = False, + **kwargs, ) -> Path | None: if save: assert output_dir is not None, ( @@ -126,7 +126,9 @@ def visualize_dataset( gc.collect() if mode == "distant": - rr.serve_web_viewer(open_browser=False, web_port=web_port) + server_uri = rr.serve_grpc(grpc_port=grpc_port) + logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy") + rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri) logging.info("Logging to Rerun") @@ -226,7 +228,7 @@ def main(): "Mode of viewing between 'local' or 'distant'. " "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " "'distant' creates a server on the distant machine where the data is stored. " - "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + "Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine." 
), ) parser.add_argument( @@ -238,8 +240,13 @@ def main(): parser.add_argument( "--ws-port", type=int, - default=9087, - help="Web socket port for rerun.io when `--mode distant` is set.", + help="deprecated, please use --grpc-port instead.", + ) + parser.add_argument( + "--grpc-port", + type=int, + default=9876, + help="gRPC port for rerun.io when `--mode distant` is set.", ) parser.add_argument( "--save", @@ -265,9 +272,7 @@ def main(): parser.add_argument( "--display-compressed-images", - type=bool, - required=True, - default=False, + action="store_true", help="If set, display compressed images in Rerun instead of uncompressed ones.", ) @@ -277,6 +282,14 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") + if kwargs["ws_port"] is not None: + logging.warning( + "--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead." + ) + logging.warning("Setting grpc_port to ws_port value.") + kwargs["grpc_port"] = kwargs.pop("ws_port") + + init_logging() logging.info("Loading dataset") dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s) From fcbf550952b3794e425e20d01ec76475be54be4e Mon Sep 17 00:00:00 2001 From: HUANG TZU-CHUN Date: Wed, 18 Feb 2026 18:27:40 +0800 Subject: [PATCH 097/111] fix(docs): update environment variable name to HF_LEROBOT_HOME in docstring (#2973) Co-authored-by: Steven Palma --- src/lerobot/datasets/lerobot_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 36bffa190..360ed8d30 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -656,7 +656,7 @@ class LeRobotDataset(torch.utils.data.Dataset): repo_id (str): This is the repo id that will be used to fetch the dataset. Locally, the dataset will be stored under root/repo_id. 
root (Path | None, optional): Local directory to use for downloading/writing files. You can also - set the LEROBOT_HOME environment variable to point to a different location. Defaults to + set the HF_LEROBOT_HOME environment variable to point to a different location. Defaults to '~/.cache/huggingface/lerobot'. episodes (list[int] | None, optional): If specified, this will only load episodes specified by their episode_index in this list. Defaults to None. From b22e0315b05447efbc9a0eb1d612192aad0337c2 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 17:32:25 +0100 Subject: [PATCH 098/111] fix(utils): more conservative sleep_margin default value in precise_sleep (#2985) --- src/lerobot/utils/robot_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lerobot/utils/robot_utils.py b/src/lerobot/utils/robot_utils.py index 28c8e7c49..656dc2649 100644 --- a/src/lerobot/utils/robot_utils.py +++ b/src/lerobot/utils/robot_utils.py @@ -16,14 +16,14 @@ import platform import time -def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003): +def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.005): """ Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage. Parameters: - seconds: duration to wait - spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms - - sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms + - sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 5ms Note: The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case. 
From 89bd58a9a26ec5820df13866b6ebc1670ed8cd83 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 18:22:35 +0100 Subject: [PATCH 099/111] chore(scripts): warn if we don't respect the target FPS (#2986) --- src/lerobot/scripts/lerobot_record.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 0b39e6fff..216ab22a6 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -398,7 +398,14 @@ def record_loop( ) dt_s = time.perf_counter() - start_loop_t - precise_sleep(max(1 / fps - dt_s, 0.0)) + + sleep_time_s: float = 1 / fps - dt_s + if sleep_time_s < 0: + logging.warning( + f"Record loop is running slower ({1 / dt_s:.1f} Hz) than the target FPS ({fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) + + precise_sleep(max(sleep_time_s, 0.0)) timestamp = time.perf_counter() - start_episode_t From aaf37070587581b3ffa8a28b6c134e846afe3a2e Mon Sep 17 00:00:00 2001 From: Caroline Pascal Date: Wed, 18 Feb 2026 19:16:53 +0100 Subject: [PATCH 100/111] fix(filtering): fixing episodes filtering in load_nested_dataset to always use .from_parquet() (#2982) --- src/lerobot/datasets/utils.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index 321ecedd5..da186bf30 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -122,19 +122,9 @@ def load_nested_dataset( raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") with SuppressProgressBars(): - # When no filtering needed, Dataset uses memory-mapped loading for efficiency - # PyArrow loads the entire dataset into memory - if episodes is None: - return Dataset.from_parquet([str(path) 
for path in paths], features=features) - - arrow_dataset = pa_ds.dataset(paths, format="parquet") - filter_expr = pa_ds.field("episode_index").isin(episodes) - table = arrow_dataset.to_table(filter=filter_expr) - - if features is not None: - table = table.cast(features.arrow_schema) - - return Dataset(table) + # We use .from_parquet() memory-mapped loading for efficiency + filters = pa_ds.field("episode_index").isin(episodes) if episodes is not None else None + return Dataset.from_parquet([str(path) for path in paths], filters=filters, features=features) def get_parquet_num_frames(parquet_path: str | Path) -> int: From bc38261321f377621a05595914798023bc05d301 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 20:05:15 +0100 Subject: [PATCH 101/111] feat(robots): use read_latest() camera (#2987) * feat(robots): use read_latest() camera * fix(test): add read_latest reachy cam mock --- src/lerobot/cameras/camera.py | 2 +- src/lerobot/cameras/opencv/camera_opencv.py | 2 +- src/lerobot/cameras/reachy2_camera/reachy2_camera.py | 2 +- src/lerobot/cameras/realsense/camera_realsense.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_arm.py | 2 +- src/lerobot/robots/hope_jr/hope_jr_hand.py | 2 +- src/lerobot/robots/koch_follower/koch_follower.py | 2 +- src/lerobot/robots/lekiwi/lekiwi.py | 2 +- src/lerobot/robots/omx_follower/omx_follower.py | 2 +- src/lerobot/robots/openarm_follower/openarm_follower.py | 2 +- src/lerobot/robots/reachy2/robot_reachy2.py | 2 +- src/lerobot/robots/so_follower/so_follower.py | 2 +- src/lerobot/robots/unitree_g1/unitree_g1.py | 2 +- tests/robots/test_reachy2.py | 1 + 14 files changed, 14 insertions(+), 13 deletions(-) diff --git a/src/lerobot/cameras/camera.py b/src/lerobot/cameras/camera.py index 2894e0215..2a53d2544 100644 --- a/src/lerobot/cameras/camera.py +++ b/src/lerobot/cameras/camera.py @@ -150,7 +150,7 @@ class Camera(abc.ABC): """ pass - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, 
max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 10b3f21da..f3289ddc7 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -530,7 +530,7 @@ class OpenCVCamera(Camera): return frame @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py index 0c1dc43d8..9bef957bc 100644 --- a/src/lerobot/cameras/reachy2_camera/reachy2_camera.py +++ b/src/lerobot/cameras/reachy2_camera/reachy2_camera.py @@ -201,7 +201,7 @@ class Reachy2Camera(Camera): return self.read() @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent frame captured immediately (Peeking). This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index d599cdce0..d80ec8093 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -573,7 +573,7 @@ class RealSenseCamera(Camera): # NOTE(Steven): Missing implementation for depth for now @check_if_not_connected - def read_latest(self, max_age_ms: int = 1000) -> NDArray[Any]: + def read_latest(self, max_age_ms: int = 500) -> NDArray[Any]: """Return the most recent (color) frame captured immediately (Peeking). 
This method is non-blocking and returns whatever is currently in the diff --git a/src/lerobot/robots/hope_jr/hope_jr_arm.py b/src/lerobot/robots/hope_jr/hope_jr_arm.py index 5fd9c4d1d..e8269ae46 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_arm.py +++ b/src/lerobot/robots/hope_jr/hope_jr_arm.py @@ -140,7 +140,7 @@ class HopeJrArm(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/hope_jr/hope_jr_hand.py b/src/lerobot/robots/hope_jr/hope_jr_hand.py index 1e5c72b72..a05c4bbcb 100644 --- a/src/lerobot/robots/hope_jr/hope_jr_hand.py +++ b/src/lerobot/robots/hope_jr/hope_jr_hand.py @@ -171,7 +171,7 @@ class HopeJrHand(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/koch_follower/koch_follower.py b/src/lerobot/robots/koch_follower/koch_follower.py index fee0adba9..53a32beed 100644 --- a/src/lerobot/robots/koch_follower/koch_follower.py +++ b/src/lerobot/robots/koch_follower/koch_follower.py @@ -193,7 +193,7 @@ class KochFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/lekiwi/lekiwi.py b/src/lerobot/robots/lekiwi/lekiwi.py index 54848f49d..9d11a000f 100644 --- a/src/lerobot/robots/lekiwi/lekiwi.py +++ b/src/lerobot/robots/lekiwi/lekiwi.py @@ -360,7 +360,7 @@ class 
LeKiwi(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/omx_follower/omx_follower.py b/src/lerobot/robots/omx_follower/omx_follower.py index a171affbd..e0b612c60 100644 --- a/src/lerobot/robots/omx_follower/omx_follower.py +++ b/src/lerobot/robots/omx_follower/omx_follower.py @@ -176,7 +176,7 @@ class OmxFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/openarm_follower/openarm_follower.py b/src/lerobot/robots/openarm_follower/openarm_follower.py index d6794a226..c865f1ec1 100644 --- a/src/lerobot/robots/openarm_follower/openarm_follower.py +++ b/src/lerobot/robots/openarm_follower/openarm_follower.py @@ -241,7 +241,7 @@ class OpenArmFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/reachy2/robot_reachy2.py b/src/lerobot/robots/reachy2/robot_reachy2.py index 6f4eef56c..fb466f85b 100644 --- a/src/lerobot/robots/reachy2/robot_reachy2.py +++ b/src/lerobot/robots/reachy2/robot_reachy2.py @@ -180,7 +180,7 @@ class Reachy2Robot(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() return obs_dict diff --git a/src/lerobot/robots/so_follower/so_follower.py 
b/src/lerobot/robots/so_follower/so_follower.py index b4d11fe3f..bc72a2b6a 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -187,7 +187,7 @@ class SOFollower(Robot): # Capture images from cameras for cam_key, cam in self.cameras.items(): start = time.perf_counter() - obs_dict[cam_key] = cam.async_read() + obs_dict[cam_key] = cam.read_latest() dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") diff --git a/src/lerobot/robots/unitree_g1/unitree_g1.py b/src/lerobot/robots/unitree_g1/unitree_g1.py index 01b4f330e..df0de8f19 100644 --- a/src/lerobot/robots/unitree_g1/unitree_g1.py +++ b/src/lerobot/robots/unitree_g1/unitree_g1.py @@ -324,7 +324,7 @@ class UnitreeG1(Robot): # Cameras - read images from ZMQ cameras for cam_name, cam in self._cameras.items(): - obs[cam_name] = cam.async_read() + obs[cam_name] = cam.read_latest() return obs diff --git a/tests/robots/test_reachy2.py b/tests/robots/test_reachy2.py index d3c44bf5a..d3f32b1c2 100644 --- a/tests/robots/test_reachy2.py +++ b/tests/robots/test_reachy2.py @@ -142,6 +142,7 @@ def _make_reachy2_camera_mock(*args, **kwargs): cam.connect = MagicMock() cam.disconnect = MagicMock() cam.async_read = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8)) + cam.read_latest = MagicMock(side_effect=lambda: np.zeros((height, width, 3), dtype=np.uint8)) return cam From 5f15232271a81ee6be16cec1960e300f55f25466 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 18 Feb 2026 22:46:12 +0100 Subject: [PATCH 102/111] chore: remove usernames + use entrypoints in docs, comments & sample commands (#2988) --- benchmarks/video/README.md | 84 +++++++++---------- docs/source/earthrover_mini_plus.mdx | 2 +- docs/source/hope_jr.mdx | 10 +-- docs/source/pi0.mdx | 2 +- docs/source/pi05.mdx | 2 +- docs/source/sarm.mdx | 8 +- docs/source/unitree_g1.mdx | 4 +- docs/source/walloss.mdx | 2 +- docs/source/xvla.mdx | 
2 +- examples/backward_compatibility/replay.py | 2 +- examples/rtc/eval_dataset.py | 20 ++--- examples/rtc/eval_with_real_robot.py | 6 +- .../v30/convert_dataset_v21_to_v30.py | 2 +- .../policies/sarm/compute_rabc_weights.py | 10 +-- .../policies/smolvla/modeling_smolvla.py | 4 +- src/lerobot/scripts/lerobot_edit_dataset.py | 32 +++---- src/lerobot/scripts/lerobot_replay.py | 2 +- 17 files changed, 97 insertions(+), 97 deletions(-) diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md index 490a4b495..1feee69c4 100644 --- a/benchmarks/video/README.md +++ b/benchmarks/video/README.md @@ -28,9 +28,9 @@ We don't expect the same optimal settings for a dataset of images from a simulat For these reasons, we run this benchmark on four representative datasets: - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera. -- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. -- `aliberts/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera. -- `aliberts/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera. +- `lerobot/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. +- `lerobot/paris_street`: (720 x 1280 pixels) real-world outdoor, moving camera. +- `lerobot/kitchen`: (1080 x 1920 pixels) real-world indoor, fixed camera. Note: The datasets used for this benchmark need to be image datasets, not video datasets. 
@@ -179,7 +179,7 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ + lerobot/aloha_mobile_shrimp_image \ --vcodec libx264 libx265 \ --pix-fmt yuv444p yuv420p \ --g 2 20 None \ @@ -203,9 +203,9 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ - aliberts/paris_street \ - aliberts/kitchen \ + lerobot/aloha_mobile_shrimp_image \ + lerobot/paris_street \ + lerobot/kitchen \ --vcodec libx264 libx265 \ --pix-fmt yuv444p yuv420p \ --g 1 2 3 4 5 6 10 15 20 40 None \ @@ -221,9 +221,9 @@ python benchmark/video/run_video_benchmark.py \ --output-dir outputs/video_benchmark \ --repo-ids \ lerobot/pusht_image \ - aliberts/aloha_mobile_shrimp_image \ - aliberts/paris_street \ - aliberts/kitchen \ + lerobot/aloha_mobile_shrimp_image \ + lerobot/paris_street \ + lerobot/kitchen \ --vcodec libsvtav1 \ --pix-fmt yuv420p \ --g 1 2 3 4 5 6 10 15 20 40 None \ @@ -252,37 +252,37 @@ Since we're using av1 encoding, we're choosing the `pyav` decoder as `video_read These tables show the results for `g=2` and `crf=30`, using `timestamps-modes=6_frames` and `backend=pyav` -| video_images_size_ratio | vcodec | pix_fmt | | | | -| ---------------------------------- | ---------- | ------- | --------- | --------- | --------- | -| | libx264 | | libx265 | | libsvtav1 | -| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% | -| aliberts/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% | -| aliberts/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% | -| aliberts/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% | +| video_images_size_ratio | vcodec | pix_fmt | | | | +| --------------------------------- | ---------- | ------- | --------- | --------- | --------- | 
+| | libx264 | | libx265 | | libsvtav1 | +| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | **16.97%** | 17.58% | 18.57% | 18.86% | 22.06% | +| lerobot/aloha_mobile_shrimp_image | 2.14% | 2.11% | 1.38% | **1.37%** | 5.59% | +| lerobot/paris_street | 2.12% | 2.13% | **1.54%** | **1.54%** | 4.43% | +| lerobot/kitchen | 1.40% | 1.39% | **1.00%** | **1.00%** | 2.52% | -| video_images_load_time_ratio | vcodec | pix_fmt | | | | -| ---------------------------------- | ------- | ------- | -------- | ------- | --------- | -| | libx264 | | libx265 | | libsvtav1 | -| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 | -| aliberts/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** | -| aliberts/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** | -| aliberts/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** | +| video_images_load_time_ratio | vcodec | pix_fmt | | | | +| --------------------------------- | ------- | ------- | -------- | ------- | --------- | +| | libx264 | | libx265 | | libsvtav1 | +| repo_id | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | 6.45 | 5.19 | **1.90** | 2.12 | 2.47 | +| lerobot/aloha_mobile_shrimp_image | 11.80 | 7.92 | 0.71 | 0.85 | **0.48** | +| lerobot/paris_street | 2.21 | 2.05 | 0.36 | 0.49 | **0.30** | +| lerobot/kitchen | 1.46 | 1.46 | 0.28 | 0.51 | **0.26** | -| | | vcodec | pix_fmt | | | | -| ---------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ | -| | | libx264 | | libx265 | | libsvtav1 | -| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | -| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 | -| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 | -| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% | -| aliberts/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 
2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** | -| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** | -| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** | -| aliberts/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** | -| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** | -| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** | -| aliberts/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** | -| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** | -| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** | +| | | vcodec | pix_fmt | | | | +| --------------------------------- | -------- | -------- | ------------ | -------- | --------- | ------------ | +| | | libx264 | | libx265 | | libsvtav1 | +| repo_id | metric | yuv420p | yuv444p | yuv420p | yuv444p | yuv420p | +| lerobot/pusht_image | avg_mse | 2.90E-04 | **2.03E-04** | 3.13E-04 | 2.29E-04 | 2.19E-04 | +| | avg_psnr | 35.44 | 37.07 | 35.49 | **37.30** | 37.20 | +| | avg_ssim | 98.28% | **98.85%** | 98.31% | 98.84% | 98.72% | +| lerobot/aloha_mobile_shrimp_image | avg_mse | 2.76E-04 | 2.59E-04 | 3.17E-04 | 3.06E-04 | **1.30E-04** | +| | avg_psnr | 35.91 | 36.21 | 35.88 | 36.09 | **40.17** | +| | avg_ssim | 95.19% | 95.18% | 95.00% | 95.05% | **97.73%** | +| lerobot/paris_street | avg_mse | 6.89E-04 | 6.70E-04 | 4.03E-03 | 4.02E-03 | **3.09E-04** | +| | avg_psnr | 33.48 | 33.68 | 32.05 | 32.15 | **35.40** | +| | avg_ssim | 93.76% | 93.75% | 89.46% | 89.46% | **95.46%** | +| lerobot/kitchen | avg_mse | 2.50E-04 | 2.24E-04 | 4.28E-04 | 4.18E-04 | **1.53E-04** | +| | avg_psnr | 36.73 | 37.33 | 36.56 | 36.75 | **39.12** | +| | avg_ssim | 95.47% | 95.58% | 95.52% | 95.53% | **96.82%** | diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index d8083336a..dd9c2ad2b 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -185,7 +185,7 @@ echo 
$HF_USER Use the standard recording command: ```bash -python src/lerobot/scripts/lerobot_record.py \ +lerobot-record \ --robot.type=earthrover_mini_plus \ --teleop.type=keyboard_rover \ --dataset.repo_id=your_username/dataset_name \ diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx index 856febb95..026cd084a 100644 --- a/docs/source/hope_jr.mdx +++ b/docs/source/hope_jr.mdx @@ -224,7 +224,7 @@ lerobot-record \ --teleop.port=/dev/tty.usbmodem1201 \ --teleop.id=right \ --teleop.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.repo_id=/hand_record_test_with_video_data \ --dataset.single_task="Hand recording test with video data" \ --dataset.num_episodes=1 \ --dataset.episode_time_s=5 \ @@ -241,7 +241,7 @@ lerobot-replay \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ --robot.side=right \ - --dataset.repo_id=nepyope/hand_record_test_with_camera \ + --dataset.repo_id=/hand_record_test_with_camera \ --dataset.episode=0 ``` @@ -249,13 +249,13 @@ lerobot-replay \ ```bash lerobot-train \ - --dataset.repo_id=nepyope/hand_record_test_with_video_data \ + --dataset.repo_id=/hand_record_test_with_video_data \ --policy.type=act \ --output_dir=outputs/train/hopejr_hand \ --job_name=hopejr \ --policy.device=mps \ --wandb.enable=true \ - --policy.repo_id=nepyope/hand_test_policy + --policy.repo_id=/hand_test_policy ``` ### Evaluate @@ -270,7 +270,7 @@ lerobot-record \ --robot.side=right \ --robot.cameras='{"main": {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30}}' \ --display_data=false \ - --dataset.repo_id=nepyope/eval_hopejr \ + --dataset.repo_id=/eval_hopejr \ --dataset.single_task="Evaluate hopejr hand policy" \ --dataset.num_episodes=10 \ --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx index 93e0b4c88..879bbd16d 100644 --- a/docs/source/pi0.mdx +++ b/docs/source/pi0.mdx @@ -60,7 +60,7 @@ 
policy.type=pi0 For training π₀, you can use the standard LeRobot training script with the appropriate configuration: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=pi0 \ --output_dir=./outputs/pi0_training \ diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx index dbf118aa3..8abaca989 100644 --- a/docs/source/pi05.mdx +++ b/docs/source/pi05.mdx @@ -56,7 +56,7 @@ policy.type=pi05 Here's a complete training command for finetuning the base π₀.₅ model on your own dataset: ```bash -python src/lerobot/scripts/lerobot_train.py\ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=pi05 \ --output_dir=./outputs/pi05_training \ diff --git a/docs/source/sarm.mdx b/docs/source/sarm.mdx index 65e49792b..cd488fe1f 100644 --- a/docs/source/sarm.mdx +++ b/docs/source/sarm.mdx @@ -269,7 +269,7 @@ This generates visualizations showing video frames with subtask boundaries overl Train with **no annotations** - uses linear progress from 0 to 1: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=single_stage \ @@ -288,7 +288,7 @@ python src/lerobot/scripts/lerobot_train.py \ Train with **dense annotations only** (sparse auto-generated): ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=dense_only \ @@ -307,7 +307,7 @@ python src/lerobot/scripts/lerobot_train.py \ Train with **both sparse and dense annotations**: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=sarm \ --policy.annotation_mode=dual \ @@ -468,7 +468,7 @@ This script: Once you have the progress file, train your policy with RA-BC weighting. The progress file is auto-detected from the dataset path (`sarm_progress.parquet`). 
Currently PI0, PI0.5 and SmolVLA are supported with RA-BC: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your-username/your-dataset \ --policy.type=pi0 \ --use_rabc=true \ diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index ea6bf54ad..4c5d28924 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -216,7 +216,7 @@ lerobot-teleoperate \ ### Record Dataset in Simulation ```bash -python -m lerobot.scripts.lerobot_record \ +lerobot-record \ --robot.type=unitree_g1 \ --robot.is_simulation=true \ --robot.cameras='{"global_view": {"type": "zmq", "server_address": "localhost", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ @@ -266,7 +266,7 @@ lerobot-teleoperate \ ### Record Dataset on Real Robot ```bash -python -m lerobot.scripts.lerobot_record \ +lerobot-record \ --robot.type=unitree_g1 \ --robot.is_simulation=false \ --robot.cameras='{"global_view": {"type": "zmq", "server_address": "172.18.129.215", "port": 5555, "camera_name": "head_camera", "width": 640, "height": 480, "fps": 30}}' \ diff --git a/docs/source/walloss.mdx b/docs/source/walloss.mdx index c0756c087..e9785cc93 100644 --- a/docs/source/walloss.mdx +++ b/docs/source/walloss.mdx @@ -45,7 +45,7 @@ policy.type=wall_x For training WallX, you can use the standard LeRobot training script with the appropriate configuration: ```bash -python src/lerobot/scripts/lerobot_train.py \ +lerobot-train \ --dataset.repo_id=your_dataset \ --policy.type=wall_x \ --output_dir=./outputs/wallx_training \ diff --git a/docs/source/xvla.mdx b/docs/source/xvla.mdx index dd7d1ef57..97e04d4ec 100644 --- a/docs/source/xvla.mdx +++ b/docs/source/xvla.mdx @@ -154,7 +154,7 @@ lerobot-train \ ```bash lerobot-train \ - --dataset.repo_id=pepijn223/bimanual-so100-handover-cube \ + --dataset.repo_id=/bimanual-so100-handover-cube \ --output_dir=./outputs/xvla_bimanual \ --job_name=xvla_so101_training \ 
--policy.path="lerobot/xvla-base" \ diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index 8de5ba197..f7c47bec5 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -22,7 +22,7 @@ lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/record-test \ --dataset.episode=2 ``` """ diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index 4652df107..613fd67d7 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -27,8 +27,8 @@ measuring consistency and ground truth alignment. Usage: # Basic usage with smolvla policy uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=mps \ --rtc.max_guidance_weight=10.0 \ @@ -58,16 +58,16 @@ Usage: --device=cuda uv run python examples/rtc/eval_dataset.py \ - --policy.path=lipsop/reuben_pi0 \ - --dataset.repo_id=ReubenLim/so101_cube_in_cup \ + --policy.path=/reuben_pi0 \ + --dataset.repo_id=/so101_cube_in_cup \ --rtc.execution_horizon=8 \ --device=cuda # With torch.compile for faster inference (PyTorch 2.0+) # Note: CUDA graphs disabled by default due to in-place ops in denoising loop uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=mps \ --use_torch_compile=true \ @@ -75,8 +75,8 @@ Usage: # With torch.compile on CUDA (CUDA graphs disabled by default) uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - 
--dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --rtc.execution_horizon=8 \ --device=cuda \ --use_torch_compile=true \ @@ -84,8 +84,8 @@ Usage: # Enable CUDA graphs (advanced - may cause tensor aliasing errors) uv run python examples/rtc/eval_dataset.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ - --dataset.repo_id=helper2424/check_rtc \ + --policy.path=/smolvla_check_rtc_last3 \ + --dataset.repo_id=/check_rtc \ --use_torch_compile=true \ --torch_compile_backend=inductor \ --torch_compile_mode=max-autotune \ diff --git a/examples/rtc/eval_with_real_robot.py b/examples/rtc/eval_with_real_robot.py index 1470899d9..4c803eb7e 100644 --- a/examples/rtc/eval_with_real_robot.py +++ b/examples/rtc/eval_with_real_robot.py @@ -28,7 +28,7 @@ For simulation environments, see eval_with_simulation.py Usage: # Run RTC with Real robot with RTC uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ + --policy.path=/smolvla_check_rtc_last3 \ --policy.device=mps \ --rtc.enabled=true \ --rtc.execution_horizon=20 \ @@ -41,7 +41,7 @@ Usage: # Run RTC with Real robot without RTC uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/smolvla_check_rtc_last3 \ + --policy.path=/smolvla_check_rtc_last3 \ --policy.device=mps \ --rtc.enabled=false \ --robot.type=so100_follower \ @@ -53,7 +53,7 @@ Usage: # Run RTC with Real robot with pi0.5 policy uv run examples/rtc/eval_with_real_robot.py \ - --policy.path=helper2424/pi05_check_rtc \ + --policy.path=/pi05_check_rtc \ --policy.device=mps \ --rtc.enabled=true \ --rtc.execution_horizon=20 \ diff --git a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py index 74be6bfa4..7be37a1b1 100644 --- a/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py +++ b/src/lerobot/datasets/v30/convert_dataset_v21_to_v30.py @@ -529,7 +529,7 @@ if __name__ == 
"__main__": type=str, required=True, help="Repository identifier on Hugging Face: a community or a user name `/` the name of the dataset " - "(e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).", + "(e.g. `lerobot/pusht`, `/aloha_sim_insertion_human`).", ) parser.add_argument( "--branch", diff --git a/src/lerobot/policies/sarm/compute_rabc_weights.py b/src/lerobot/policies/sarm/compute_rabc_weights.py index 5b6ea6e9b..485c1096b 100644 --- a/src/lerobot/policies/sarm/compute_rabc_weights.py +++ b/src/lerobot/policies/sarm/compute_rabc_weights.py @@ -27,18 +27,18 @@ Usage: # Full RA-BC computation with visualizations python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 + --reward-model-path /sarm_single_uni4 # Faster computation with stride (compute every 5 frames, interpolate the rest) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ --stride 5 # Visualize predictions only (no RA-BC computation) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ --visualize-only \\ --num-visualizations 5 @@ -714,12 +714,12 @@ Examples: # Full RA-BC computation with visualizations python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 + --reward-model-path /sarm_single_uni4 # Visualize predictions only (no RA-BC computation) python src/lerobot/policies/sarm/compute_rabc_weights.py \\ --dataset-repo-id lerobot/aloha_sim_insertion_human \\ - --reward-model-path pepijn223/sarm_single_uni4 \\ + --reward-model-path /sarm_single_uni4 \\ 
--visualize-only \\ --num-visualizations 10 """, diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 60b968a42..10544a949 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -30,7 +30,7 @@ Example of finetuning the smolvla pretrained model (`smolvla_base`): ```bash lerobot-train \ --policy.path=lerobot/smolvla_base \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--dataset.repo_id=/svla_so100_task1_v3 \ --batch_size=64 \ --steps=200000 ``` @@ -40,7 +40,7 @@ and an action expert. ```bash lerobot-train \ --policy.type=smolvla \ ---dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--dataset.repo_id=/svla_so100_task1_v3 \ --batch_size=64 \ --steps=200000 ``` diff --git a/src/lerobot/scripts/lerobot_edit_dataset.py b/src/lerobot/scripts/lerobot_edit_dataset.py index 06e256fa2..afdc95efd 100644 --- a/src/lerobot/scripts/lerobot_edit_dataset.py +++ b/src/lerobot/scripts/lerobot_edit_dataset.py @@ -24,100 +24,100 @@ When new_repo_id is specified, creates a new dataset. 
Usage Examples: Delete episodes 0, 2, and 5 from a dataset: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" Delete episodes and save to a new dataset: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --new_repo_id lerobot/pusht_filtered \ --operation.type delete_episodes \ --operation.episode_indices "[0, 2, 5]" Split dataset by fractions: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": 0.8, "val": 0.2}' Split dataset by episode indices: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": [0, 1, 2, 3], "val": [4, 5]}' Split into more than two splits: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type split \ --operation.splits '{"train": 0.6, "val": 0.2, "test": 0.2}' Merge multiple datasets: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_merged \ --operation.type merge \ --operation.repo_ids "['lerobot/pusht_train', 'lerobot/pusht_val']" Remove camera feature: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type remove_feature \ --operation.feature_names "['observation.images.top']" Modify tasks - set a single task for all episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.new_task "Pick up the cube and place it" Modify tasks - set different tasks for specific episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + 
lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.episode_tasks '{"0": "Task A", "1": "Task B", "2": "Task A"}' Modify tasks - set default task with overrides for specific episodes (WARNING: modifies in-place): - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht \ --operation.type modify_tasks \ --operation.new_task "Default task" \ --operation.episode_tasks '{"5": "Special task for episode 5"}' Convert image dataset to video format and save locally: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type convert_image_to_video \ --operation.output_dir /path/to/output/pusht_video Convert image dataset to video format and save with new repo_id: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ --operation.type convert_image_to_video Convert image dataset to video format and push to hub: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --new_repo_id lerobot/pusht_video \ --operation.type convert_image_to_video \ --push_to_hub true Show dataset information: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type info \ --operation.show_features true Show dataset information without feature details: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --repo_id lerobot/pusht_image \ --operation.type info \ --operation.show_features false Using JSON config file: - python -m lerobot.scripts.lerobot_edit_dataset \ + lerobot-edit-dataset \ --config_path path/to/edit_config.json """ diff --git a/src/lerobot/scripts/lerobot_replay.py b/src/lerobot/scripts/lerobot_replay.py index c9a559d07..8e2a394b9 100644 --- a/src/lerobot/scripts/lerobot_replay.py +++ 
b/src/lerobot/scripts/lerobot_replay.py @@ -22,7 +22,7 @@ lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ - --dataset.repo_id=aliberts/record-test \ + --dataset.repo_id=/record-test \ --dataset.episode=0 ``` From 2dd366436ed30ed9729b4f18076a54fec7ec589b Mon Sep 17 00:00:00 2001 From: Khalil Date: Thu, 19 Feb 2026 14:35:02 +0100 Subject: [PATCH 103/111] Fix gym-hil integration with the new LeRobot pipeline. (#2482) * Add GymHILAdapterProcessorStep for gym-hil environment integration * Fix action features in control loop for None teleop device with gym-hil * Finalize dataset before pushing to hub for visualization on the hub * Fix neutral action for gripper * fix pre-commit --- src/lerobot/processor/__init__.py | 2 ++ src/lerobot/processor/gym_action_processor.py | 8 +++++ src/lerobot/processor/hil_processor.py | 31 +++++++++++++++++++ src/lerobot/rl/gym_manipulator.py | 15 +++++++-- 4 files changed, 54 insertions(+), 2 deletions(-) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 164f7da03..0b63e1606 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -44,6 +44,7 @@ from .hil_processor import ( AddTeleopActionAsComplimentaryDataStep, AddTeleopEventsAsInfoStep, GripperPenaltyProcessorStep, + GymHILAdapterProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, RewardClassifierProcessorStep, @@ -87,6 +88,7 @@ __all__ = [ "DoneProcessorStep", "EnvAction", "EnvTransition", + "GymHILAdapterProcessorStep", "GripperPenaltyProcessorStep", "hotswap_stats", "IdentityProcessorStep", diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 8fa8cfd86..4f225af92 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -20,6 +20,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature from 
.converters import to_tensor from .core import EnvAction, EnvTransition, PolicyAction +from .hil_processor import TELEOP_ACTION_KEY from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry @@ -89,6 +90,13 @@ class Numpy2TorchActionProcessorStep(ProcessorStep): torch_action = to_tensor(action, dtype=None) # Preserve original dtype new_transition[TransitionKey.ACTION] = torch_action + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if TELEOP_ACTION_KEY in complementary_data: + teleop_action = complementary_data[TELEOP_ACTION_KEY] + if isinstance(teleop_action, EnvAction): + complementary_data[TELEOP_ACTION_KEY] = to_tensor(teleop_action) + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition def transform_features( diff --git a/src/lerobot/processor/hil_processor.py b/src/lerobot/processor/hil_processor.py index 24b5628fa..34eaeed51 100644 --- a/src/lerobot/processor/hil_processor.py +++ b/src/lerobot/processor/hil_processor.py @@ -312,6 +312,37 @@ class TimeLimitProcessorStep(TruncatedProcessorStep): return features +@ProcessorStepRegistry.register("gym_hil_adapter_processor") +class GymHILAdapterProcessorStep(ProcessorStep): + """ + Adapts the output of the `gym-hil` environment to the format expected by `lerobot` processors. + + This step normalizes the `transition` object by: + 1. Copying `teleop_action` from `info` to `complementary_data`. + 2. Copying `is_intervention` from `info` (using the string key) to `info` (using the enum key). 
+ """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: + info = transition.get(TransitionKey.INFO, {}) + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + if TELEOP_ACTION_KEY in info: + complementary_data[TELEOP_ACTION_KEY] = info[TELEOP_ACTION_KEY] + + if "is_intervention" in info: + info[TeleopEvents.IS_INTERVENTION] = info["is_intervention"] + + transition[TransitionKey.INFO] = info + transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + + return transition + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + return features + + @dataclass @ProcessorStepRegistry.register("gripper_penalty_processor") class GripperPenaltyProcessorStep(ProcessorStep): diff --git a/src/lerobot/rl/gym_manipulator.py b/src/lerobot/rl/gym_manipulator.py index 1c1cb752f..f5fcb7437 100644 --- a/src/lerobot/rl/gym_manipulator.py +++ b/src/lerobot/rl/gym_manipulator.py @@ -36,6 +36,7 @@ from lerobot.processor import ( DeviceProcessorStep, EnvTransition, GripperPenaltyProcessorStep, + GymHILAdapterProcessorStep, ImageCropResizeProcessorStep, InterventionActionProcessorStep, MapDeltaActionToRobotActionStep, @@ -379,6 +380,7 @@ def make_processors( ] env_pipeline_steps = [ + GymHILAdapterProcessorStep(), Numpy2TorchActionProcessorStep(), VanillaObservationProcessorStep(), AddBatchDimensionProcessorStep(), @@ -608,7 +610,14 @@ def control_loop( dataset = None if cfg.mode == "record": - action_features = teleop_device.action_features + if teleop_device: + action_features = teleop_device.action_features + else: + action_features = { + "dtype": "float32", + "shape": (4,), + "names": ["delta_x", "delta_y", "delta_z", "gripper"], + } features = { ACTION: action_features, REWARD: {"dtype": "float32", "shape": (1,), "names": None}, @@ -656,7 +665,7 @@ def control_loop( # Create a neutral action (no movement) neutral_action = 
torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32) if use_gripper: - neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay + neutral_action = torch.cat([neutral_action, torch.tensor([0.0])]) # Gripper stay # Use the new step function transition = step_env_and_process_transition( @@ -725,6 +734,8 @@ def control_loop( precise_sleep(max(dt - (time.perf_counter() - step_start_time), 0.0)) if dataset is not None and cfg.dataset.push_to_hub: + logging.info("Finalizing dataset before pushing to hub") + dataset.finalize() logging.info("Pushing dataset to hub") dataset.push_to_hub() From 5865170d36442b907bb35f946e837eee18aafdf1 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Fri, 20 Feb 2026 17:01:46 +0100 Subject: [PATCH 104/111] chore(deps): bump ceil datasets (#2946) --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e5431ada3..0ca1f0432 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici dependencies = [ # Hugging Face dependencies - "datasets>=4.0.0,<4.2.0", + "datasets>=4.0.0,<5.0.0", "diffusers>=0.27.2,<0.36.0", "huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0", "accelerate>=1.10.0,<2.0.0", From e96339a3b49ac080e11664d2be4737987d861a57 Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 23 Feb 2026 13:57:43 +0100 Subject: [PATCH 105/111] feat(dataset): add streaming video encoding + HW encoder support (#2974) * feat(dataset): init stream encoding * feat(dataset): use threads to fix frame pickle latency * refactor(dataset): remove HW encoded related changes * add lp (#2977) * feat(dataset): add Hw encoding + log drop frames (#2978) * chore(docs): add streaming video encoding guide * fix(dataset): style docs + testing * chore(docs): simplify streaming video encoding guide * chore(dataset): add commands + streaming encoding default false + print note if false + queue
default is now 30 * chore(docs): add verification note advice * chore(dataset): adjusting defaults & docs for streaming encoding * docs(scripts): improve docstrings * test(dataset): polish streaming encoding tests * chore(dataset): move FYI log related to streaming * chore(dataset): add arg vcodec to suggestions * refactor(dataset): better handling for auto and available vcodec * chore(dataset): change log level * docs(dataset): add note related to training performance vcodec * docs(dataset): add more notes to streaming encoding --------- Co-authored-by: Caroline Pascal Co-authored-by: Pepijn --- docs/source/_toctree.yml | 2 + docs/source/act.mdx | 3 + docs/source/earthrover_mini_plus.mdx | 3 + docs/source/groot.mdx | 9 +- docs/source/hope_jr.mdx | 6 + docs/source/il_robots.mdx | 8 +- docs/source/lerobot-dataset-v3.mdx | 5 +- docs/source/reachy2.mdx | 6 + docs/source/smolvla.mdx | 3 + docs/source/streaming_video_encoding.mdx | 155 ++++ docs/source/unitree_g1.mdx | 10 +- src/lerobot/datasets/lerobot_dataset.py | 124 ++- src/lerobot/datasets/video_utils.py | 480 +++++++++++- src/lerobot/scripts/lerobot_record.py | 34 +- tests/datasets/test_datasets.py | 9 +- .../datasets/test_streaming_video_encoder.py | 730 ++++++++++++++++++ 16 files changed, 1532 insertions(+), 55 deletions(-) create mode 100644 docs/source/streaming_video_encoding.mdx create mode 100644 tests/datasets/test_streaming_video_encoder.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d61aac9c1..1055975d7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -29,6 +29,8 @@ title: Using the Dataset Tools - local: dataset_subtask title: Using Subtasks in the Dataset + - local: streaming_video_encoding + title: Streaming Video Encoding title: "Datasets" - sections: - local: act diff --git a/docs/source/act.mdx b/docs/source/act.mdx index e3294ca69..453bcbba8 100644 --- a/docs/source/act.mdx +++ b/docs/source/act.mdx @@ -88,5 +88,8 @@ lerobot-record \ 
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \ --dataset.num_episodes=10 \ --dataset.single_task="Your task description" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --policy.path=${HF_USER}/act_policy ``` diff --git a/docs/source/earthrover_mini_plus.mdx b/docs/source/earthrover_mini_plus.mdx index dd9c2ad2b..cfc3a2eef 100644 --- a/docs/source/earthrover_mini_plus.mdx +++ b/docs/source/earthrover_mini_plus.mdx @@ -192,6 +192,9 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.fps=10 \ --dataset.single_task="Navigate around obstacles" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` diff --git a/docs/source/groot.mdx b/docs/source/groot.mdx index 8bfc22996..0ef591466 100644 --- a/docs/source/groot.mdx +++ b/docs/source/groot.mdx @@ -120,9 +120,12 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=/eval_groot-bimanual \ --dataset.num_episodes=10 \ - --dataset.single_task="Grab and handover the red cube to the other arm" - --policy.path=/groot-bimanual # your trained model - --dataset.episode_time_s=30 + --dataset.single_task="Grab and handover the red cube to the other arm" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ + --policy.path=/groot-bimanual \ # your trained model + --dataset.episode_time_s=30 \ --dataset.reset_time_s=10 ``` diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx index 026cd084a..8826d9758 100644 --- a/docs/source/hope_jr.mdx +++ b/docs/source/hope_jr.mdx @@ -230,6 +230,9 @@ lerobot-record \ --dataset.episode_time_s=5 \ --dataset.push_to_hub=true \ --dataset.private=true \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` @@ -273,5 +276,8 @@ lerobot-record \ --dataset.repo_id=/eval_hopejr \ --dataset.single_task="Evaluate hopejr hand policy" \ 
--dataset.num_episodes=10 \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model ``` diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 84dc6f2f6..7fc770b0c 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -185,7 +185,10 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=${HF_USER}/record-test \ --dataset.num_episodes=5 \ - --dataset.single_task="Grab the black cube" + --dataset.single_task="Grab the black cube" \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` @@ -515,6 +518,9 @@ lerobot-record \ --display_data=false \ --dataset.repo_id=${HF_USER}/eval_so100 \ --dataset.single_task="Put lego brick into the transparent box" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ # <- Teleop optional if you want to teleoperate in between episodes \ # --teleop.type=so100_leader \ # --teleop.port=/dev/ttyACM0 \ diff --git a/docs/source/lerobot-dataset-v3.mdx b/docs/source/lerobot-dataset-v3.mdx index 3521914f2..235a355bd 100644 --- a/docs/source/lerobot-dataset-v3.mdx +++ b/docs/source/lerobot-dataset-v3.mdx @@ -41,7 +41,10 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=${HF_USER}/record-test \ --dataset.num_episodes=5 \ - --dataset.single_task="Grab the black cube" + --dataset.single_task="Grab the black cube" \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` See the [recording guide](./il_robots#record-a-dataset) for more details. 
diff --git a/docs/source/reachy2.mdx b/docs/source/reachy2.mdx index 51b09acd2..1b868711a 100644 --- a/docs/source/reachy2.mdx +++ b/docs/source/reachy2.mdx @@ -159,6 +159,9 @@ lerobot-record \ --dataset.fps=15 \ --dataset.push_to_hub=true \ --dataset.private=true \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` @@ -198,6 +201,9 @@ lerobot-record \ --dataset.fps=15 \ --dataset.push_to_hub=true \ --dataset.private=true \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ --display_data=true ``` diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index a56298b5e..bf8a0d2f0 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -106,6 +106,9 @@ lerobot-record \ --dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub --dataset.episode_time_s=50 \ --dataset.num_episodes=10 \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ + # --dataset.vcodec=auto \ # <- Teleop optional if you want to teleoperate in between episodes \ # --teleop.type=so100_leader \ # --teleop.port=/dev/ttyACM0 \ diff --git a/docs/source/streaming_video_encoding.mdx b/docs/source/streaming_video_encoding.mdx new file mode 100644 index 000000000..40004200e --- /dev/null +++ b/docs/source/streaming_video_encoding.mdx @@ -0,0 +1,155 @@ +# Streaming Video Encoding Guide + +## 1. Overview + +Streaming video encoding eliminates the traditional PNG round-trip during video dataset recording. Instead of: + +1. Capture frame -> write PNG to disk -> (at episode end) read PNG's -> encode to MP4 -> delete PNG's + +Frames can be encoded in real-time during capture: + +1. 
Capture frame -> queue to encoder thread -> encode to MP4 directly + +This makes `save_episode()` near-instant (the video is already encoded by the time the episode ends) and removes the blocking wait that previously occurred between episodes, especially with multiple cameras in long episodes. + +## 2. Tuning Parameters + +| Parameter | CLI Flag | Type | Default | Description | +| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- | +| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture | +| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder | +| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` lets the codec decide | +| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM | + +## 3. Performance Considerations + +Streaming encoding means the CPU is encoding video **during** the capture loop, not after. 
This creates a CPU budget that must be shared between: + +- **Control loop** (reading cameras, controlling the robot, writing non-video data) +- **Encoder threads** (one pool per camera) +- **Rerun visualization** (if enabled) +- **OS and other processes** + +### Resolution & Number of Cameras Impact + +| Setup | Throughput (px/sec) | CPU Encoding Load | Notes | +| ------------------------- | ------------------- | ----------------- | ------------------------------ | +| 2camsx 640x480x3 @30fps | 55M | Low | Works on most systems | +| 2camsx 1280x720x3 @30fps | 165M | Moderate | Comfortable on modern systems | +| 2camsx 1920x1080x3 @30fps | 373M | High | Requires powerful high-end CPU | + +### `encoder_threads` Tuning + +This parameter controls how many threads each encoder instance uses internally: + +- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores. +- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs. +- **`None` (default)**: Lets the codec decide. Information available in the codec logs. + +### Backpressure and Frame Dropping + +Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up: + +1. The queue fills up (consuming RAM) +2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted +3. A warning is logged: `"Encoder queue full for {camera}, dropped N frame(s)"` +4. At episode end, total dropped frames per camera are reported + +### Symptoms of Encoder Falling Behind + +- **System feels laggy and freezes**: all CPUs are at 100% +- **Dropped frame warnings** in the log or lower frames/FPS than expected in the recorded dataset +- **Choppy robot movement**: If CPU is severely overloaded, even the capture loop may be affected +- **Accumulated rerun lag**: Visualization falls behind real-time + +## 4. 
Hardware-Accelerated Encoding + +### When to Use + +Use HW encoding when: + +- CPU is the bottleneck (dropped frames, choppy robot, rerun lag) +- You have compatible hardware (GPU or dedicated encoder) +- You're recording at high throughput (high resolution or with many cameras) + +### Choosing a Codec + +| Codec | CPU Usage | File Size | Quality | Notes | +| --------------------- | --------- | -------------- | ------- | ---------------------------------------------------------------- | +| `libsvtav1` (default) | High | Smallest | Best | Default. Best compression but most CPU-intensive | +| `h264` | Medium | ~30-50% larger | Good | Software H.264. Lower CPU | +| HW encoders | Very Low | Largest | Good | Offloads to dedicated hardware. Best for CPU-constrained systems | + +### Available HW Encoders + +| Encoder | Platform | Hardware | CLI Value | +| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ | +| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` | +| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` | +| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` | +| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` | +| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` | +| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` | +| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` | + +> [!NOTE] +> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers. 
+ +> [!NOTE] +> `libsvtav1` is the default because it provides the best training performance; other vcodecs can reduce CPU usage and be faster, but they typically produce larger files and may affect training time. + +## 5. Troubleshooting + +| Symptom | Likely Cause | Fix | +| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) | +| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). | +| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding | +| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows | +| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` | +| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` | +| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. 
If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. | + +## 6. Recommended Configurations + +These estimates are conservative; we recommend testing them on your setup—start with a low load and increase it gradually. + +### High-End Systems: modern 12+ cores (24+ threads) + +A throughput between ~250-500M px/sec should be comfortable in CPU. For even better results try HW encoding if available. + +```bash +# 3camsx 1280x720x3 @30fps: Defaults work well. Optionally increase encoder parallelism. +# 2camsx 1920x1080x3 @30fps: Defaults work well. Optionally increase encoder parallelism. +lerobot-record --dataset.encoder_threads=5 ... + +# 3camsx 1920x1080x3 @30fps: Might require some tuning. +``` + +### Mid-Range Systems: modern 8+ cores (16+ threads) or Apple Silicon + +A throughput between ~80-300M px/sec should be possible in CPU. + +```bash +# 3camsx 640x480x3 @30fps: Defaults work well. Optionally decrease encoder parallelism. +# 2camsx 1280x720x3 @30fps: Defaults work well. Optionally decrease encoder parallelism. +lerobot-record --dataset.encoder_threads=2 ... + +# 2camsx 1920x1080x3 @30fps: Might require some tuning. +``` + +### Low-Resource Systems: modern 4+ cores (8+ threads) or Raspberry Pi 5 + +On very constrained systems, streaming encoding may compete too heavily with the capture loop. Disabling it falls back to the PNG-based approach where encoding happens between episodes (blocking, but doesn't interfere with capture). Alternatively, record at a lower throughput to reduce both capture and encoding load. Consider also changing codec to `h264` and using batch encoding. + +```bash +# 2camsx 640x480x3 @30fps: Requires some tuning. + +# Use H.264, disable streaming, consider batching encoding +lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ... +``` + +## 7. 
Closing note + +Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and +`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration. diff --git a/docs/source/unitree_g1.mdx b/docs/source/unitree_g1.mdx index 4c5d28924..76e972dca 100644 --- a/docs/source/unitree_g1.mdx +++ b/docs/source/unitree_g1.mdx @@ -229,7 +229,10 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.episode_time_s=5 \ --dataset.reset_time_s=5 \ - --dataset.push_to_hub=true + --dataset.push_to_hub=true \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim) @@ -279,7 +282,10 @@ lerobot-record \ --dataset.num_episodes=2 \ --dataset.episode_time_s=5 \ --dataset.reset_time_s=5 \ - --dataset.push_to_hub=true + --dataset.push_to_hub=true \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` **Note**: Update `server_address` to match your robot's camera server IP. 
diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 360ed8d30..65b475e26 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -68,6 +68,7 @@ from lerobot.datasets.utils import ( write_tasks, ) from lerobot.datasets.video_utils import ( + StreamingVideoEncoder, VideoFrame, concatenate_video_files, decode_video_frames, @@ -75,11 +76,11 @@ from lerobot.datasets.video_utils import ( get_safe_default_codec, get_video_duration_in_s, get_video_info, + resolve_vcodec, ) from lerobot.utils.constants import HF_LEROBOT_HOME CODEBASE_VERSION = "v3.0" -VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"} class LeRobotDatasetMetadata: @@ -545,12 +546,19 @@ class LeRobotDatasetMetadata: def _encode_video_worker( - video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1" + video_key: str, + episode_index: int, + root: Path, + fps: int, + vcodec: str = "libsvtav1", + encoder_threads: int | None = None, ) -> Path: temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4" fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0) img_dir = (root / fpath).parent - encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True) + encode_video_frames( + img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads + ) shutil.rmtree(img_dir) return temp_path @@ -570,6 +578,9 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", + streaming_encoding: bool = False, + encoder_queue_maxsize: int = 30, + encoder_threads: int | None = None, ): """ 2 modes are available for instantiating this class, depending on 2 different use cases: @@ -683,12 +694,17 @@ class LeRobotDataset(torch.utils.data.Dataset): batch_encoding_size (int, optional): Number of episodes to accumulate before 
batch encoding videos. Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1. vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc', - 'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1 - encoding is CPU-heavy. + 'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'. + Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder. + streaming_encoding (bool, optional): If True, encode video frames in real-time during capture + instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False. + encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using + streaming encoding. Defaults to 30 (~1s at 30fps). + encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the + codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for + libsvtav1 and 'threads' for h264/hevc. """ super().__init__() - if vcodec not in VALID_VIDEO_CODECS: - raise ValueError(f"Invalid vcodec '{vcodec}'. 
Must be one of: {sorted(VALID_VIDEO_CODECS)}") self.repo_id = repo_id self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id self.image_transforms = image_transforms @@ -700,7 +716,8 @@ class LeRobotDataset(torch.utils.data.Dataset): self.delta_indices = None self.batch_encoding_size = batch_encoding_size self.episodes_since_last_encoding = 0 - self.vcodec = vcodec + self.vcodec = resolve_vcodec(vcodec) + self._encoder_threads = encoder_threads # Unused attributes self.image_writer = None @@ -708,6 +725,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.writer = None self.latest_episode = None self._current_file_start_frame = None # Track the starting frame index of the current parquet file + self._streaming_encoder = None self.root.mkdir(exist_ok=True, parents=True) @@ -749,6 +767,19 @@ class LeRobotDataset(torch.utils.data.Dataset): check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s) self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps) + # Initialize streaming encoder for resumed recording + if streaming_encoding and len(self.meta.video_keys) > 0: + self._streaming_encoder = StreamingVideoEncoder( + fps=self.meta.fps, + vcodec=self.vcodec, + pix_fmt="yuv420p", + g=2, + crf=30, + preset=None, + queue_maxsize=encoder_queue_maxsize, + encoder_threads=encoder_threads, + ) + def _close_writer(self) -> None: """Close and cleanup the parquet writer if it exists.""" writer = getattr(self, "writer", None) @@ -1104,6 +1135,8 @@ class LeRobotDataset(torch.utils.data.Dataset): """ self._close_writer() self.meta._close_writer() + if self._streaming_encoder is not None: + self._streaming_encoder.close() def create_episode_buffer(self, episode_index: int | None = None) -> dict: current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index @@ -1158,6 +1191,13 @@ class LeRobotDataset(torch.utils.data.Dataset): self.episode_buffer["timestamp"].append(timestamp) 
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing + # Start streaming encoder on first frame of episode (once, before iterating keys) + if frame_index == 0 and self._streaming_encoder is not None: + self._streaming_encoder.start_episode( + video_keys=list(self.meta.video_keys), + temp_dir=self.root, + ) + # Add frame features to episode_buffer for key in frame: if key not in self.features: @@ -1165,7 +1205,10 @@ class LeRobotDataset(torch.utils.data.Dataset): f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'." ) - if self.features[key]["dtype"] in ["image", "video"]: + if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None: + self._streaming_encoder.feed_frame(key, frame[key]) + self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet) + elif self.features[key]["dtype"] in ["image", "video"]: img_path = self._get_image_file_path( episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index ) @@ -1226,13 +1269,38 @@ class LeRobotDataset(torch.utils.data.Dataset): # Wait for image writer to end, so that episode stats over images can be computed self._wait_image_writer() - ep_stats = compute_episode_stats(episode_buffer, self.features) - ep_metadata = self._save_episode_data(episode_buffer) has_video_keys = len(self.meta.video_keys) > 0 + use_streaming = self._streaming_encoder is not None and has_video_keys use_batched_encoding = self.batch_encoding_size > 1 - if has_video_keys and not use_batched_encoding: + if use_streaming: + # Compute stats for non-video features only (video stats come from encoder) + non_video_buffer = { + k: v + for k, v in episode_buffer.items() + if self.features.get(k, {}).get("dtype") not in ("video",) + } + non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"} + ep_stats = compute_episode_stats(non_video_buffer, non_video_features) + 
else: + ep_stats = compute_episode_stats(episode_buffer, self.features) + + ep_metadata = self._save_episode_data(episode_buffer) + + if use_streaming: + # Finish streaming encoding and collect results + streaming_results = self._streaming_encoder.finish_episode() + for video_key in self.meta.video_keys: + temp_path, video_stats = streaming_results[video_key] + if video_stats is not None: + # Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1) + ep_stats[video_key] = { + k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0) + for k, v in video_stats.items() + } + ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path)) + elif has_video_keys and not use_batched_encoding: num_cameras = len(self.meta.video_keys) if parallel_encoding and num_cameras > 1: # TODO(Steven): Ideally we would like to control the number of threads per encoding such that: @@ -1246,6 +1314,7 @@ class LeRobotDataset(torch.utils.data.Dataset): self.root, self.fps, self.vcodec, + self._encoder_threads, ): video_key for video_key in self.meta.video_keys } @@ -1514,6 +1583,10 @@ class LeRobotDataset(torch.utils.data.Dataset): return metadata def clear_episode_buffer(self, delete_images: bool = True) -> None: + # Cancel streaming encoder if active + if self._streaming_encoder is not None: + self._streaming_encoder.cancel_episode() + # Clean up image files for the current episode buffer if delete_images: # Wait for the async image writer to finish @@ -1561,7 +1634,9 @@ class LeRobotDataset(torch.utils.data.Dataset): Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding, since video encoding with ffmpeg is already using multithreading. 
""" - return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec) + return _encode_video_worker( + video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads + ) @classmethod def create( @@ -1578,10 +1653,12 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", + streaming_encoding: bool = False, + encoder_queue_maxsize: int = 30, + encoder_threads: int | None = None, ) -> "LeRobotDataset": """Create a LeRobot Dataset from scratch in order to record data.""" - if vcodec not in VALID_VIDEO_CODECS: - raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") + vcodec = resolve_vcodec(vcodec) obj = cls.__new__(cls) obj.meta = LeRobotDatasetMetadata.create( repo_id=repo_id, @@ -1599,6 +1676,7 @@ class LeRobotDataset(torch.utils.data.Dataset): obj.batch_encoding_size = batch_encoding_size obj.episodes_since_last_encoding = 0 obj.vcodec = vcodec + obj._encoder_threads = encoder_threads if image_writer_processes or image_writer_threads: obj.start_image_writer(image_writer_processes, image_writer_threads) @@ -1620,6 +1698,22 @@ class LeRobotDataset(torch.utils.data.Dataset): obj._lazy_loading = False obj._recorded_frames = 0 obj._writer_closed_for_reading = False + + # Initialize streaming encoder + if streaming_encoding and len(obj.meta.video_keys) > 0: + obj._streaming_encoder = StreamingVideoEncoder( + fps=fps, + vcodec=vcodec, + pix_fmt="yuv420p", + g=2, + crf=30, + preset=None, + queue_maxsize=encoder_queue_maxsize, + encoder_threads=encoder_threads, + ) + else: + obj._streaming_encoder = None + return obj diff --git a/src/lerobot/datasets/video_utils.py b/src/lerobot/datasets/video_utils.py index 84ce13772..acc24a9e0 100644 --- a/src/lerobot/datasets/video_utils.py +++ b/src/lerobot/datasets/video_utils.py @@ -13,25 +13,106 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 
implied. # See the License for the specific language governing permissions and # limitations under the License. +import contextlib import glob import importlib import logging +import queue import shutil import tempfile +import threading import warnings from dataclasses import dataclass, field +from fractions import Fraction from pathlib import Path from threading import Lock from typing import Any, ClassVar import av import fsspec +import numpy as np import pyarrow as pa import torch import torchvision from datasets.features.features import register_feature from PIL import Image +# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build. +# Determines the order of preference for auto-selection when vcodec="auto" is used. +HW_ENCODERS = [ + "h264_videotoolbox", # macOS + "hevc_videotoolbox", # macOS + "h264_nvenc", # NVIDIA GPU + "hevc_nvenc", # NVIDIA GPU + "h264_vaapi", # Linux Intel/AMD + "h264_qsv", # Intel Quick Sync +] + +VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS) + + +def _get_codec_options( + vcodec: str, + g: int | None = 2, + crf: int | None = 30, + preset: int | None = None, +) -> dict: + """Build codec-specific options dict for video encoding.""" + options = {} + + # GOP size (keyframe interval) - supported by VideoToolbox and software encoders + if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS): + options["g"] = str(g) + + # Quality control (codec-specific parameter names) + if crf is not None: + if vcodec in ("h264", "hevc", "libsvtav1"): + options["crf"] = str(crf) + elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"): + quality = max(1, min(100, int(100 - crf * 2))) + options["q:v"] = str(quality) + elif vcodec in ("h264_nvenc", "hevc_nvenc"): + options["rc"] = "constqp" + options["qp"] = str(crf) + elif vcodec in ("h264_vaapi",): + options["qp"] = str(crf) + elif vcodec in ("h264_qsv",): + 
options["global_quality"] = str(crf) + + # Preset (only for libsvtav1) + if vcodec == "libsvtav1": + options["preset"] = str(preset) if preset is not None else "12" + + return options + + +def detect_available_hw_encoders() -> list[str]: + """Probe PyAV/FFmpeg for available hardware video encoders.""" + available = [] + for codec_name in HW_ENCODERS: + try: + av.codec.Codec(codec_name, "w") + available.append(codec_name) + except Exception: # nosec B110 + pass # nosec B110 + return available + + +def resolve_vcodec(vcodec: str) -> str: + """Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1.""" + if vcodec not in VALID_VIDEO_CODECS: + raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}") + if vcodec != "auto": + logging.info(f"Using video codec: {vcodec}") + return vcodec + available = detect_available_hw_encoders() + for encoder in HW_ENCODERS: + if encoder in available: + logging.info(f"Auto-selected video codec: {encoder}") + return encoder + logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'") + return "libsvtav1" + def get_safe_default_codec(): if importlib.util.find_spec("torchcodec"): @@ -309,14 +390,13 @@ def encode_video_frames( g: int | None = 2, crf: int | None = 30, fast_decode: int = 0, - log_level: int | None = av.logging.ERROR, + log_level: int | None = av.logging.WARNING, overwrite: bool = False, preset: int | None = None, + encoder_threads: int | None = None, ) -> None: """More info on ffmpeg arguments tuning on `benchmark/video/README.md`""" - # Check encoder availability - if vcodec not in ["h264", "hevc", "libsvtav1"]: - raise ValueError(f"Unsupported video codec: {vcodec}. 
Supported codecs are: h264, hevc, libsvtav1.") + vcodec = resolve_vcodec(vcodec) video_path = Path(video_path) imgs_dir = Path(imgs_dir) @@ -347,21 +427,22 @@ def encode_video_frames( width, height = dummy_image.size # Define video codec options - video_options = {} - - if g is not None: - video_options["g"] = str(g) - - if crf is not None: - video_options["crf"] = str(crf) + video_options = _get_codec_options(vcodec, g, crf, preset) if fast_decode: key = "svtav1-params" if vcodec == "libsvtav1" else "tune" value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode" video_options[key] = value - if vcodec == "libsvtav1": - video_options["preset"] = str(preset) if preset is not None else "12" + if encoder_threads is not None: + if vcodec == "libsvtav1": + lp_param = f"lp={encoder_threads}" + if "svtav1-params" in video_options: + video_options["svtav1-params"] += f":{lp_param}" + else: + video_options["svtav1-params"] = lp_param + else: + video_options["threads"] = str(encoder_threads) # Set logging level if log_level is not None: @@ -480,6 +561,348 @@ def concatenate_video_files( Path(tmp_concatenate_path).unlink() +class _CameraEncoderThread(threading.Thread): + """A thread that encodes video frames streamed via a queue into an MP4 file. + + One instance is created per camera per episode. Frames are received as numpy arrays + from the main thread, encoded in real-time using PyAV (which releases the GIL during + encoding), and written to disk. Stats are computed incrementally using + RunningQuantileStats and returned via result_queue. 
+ """ + + def __init__( + self, + video_path: Path, + fps: int, + vcodec: str, + pix_fmt: str, + g: int | None, + crf: int | None, + preset: int | None, + frame_queue: queue.Queue, + result_queue: queue.Queue, + stop_event: threading.Event, + encoder_threads: int | None = None, + ): + super().__init__(daemon=True) + self.video_path = video_path + self.fps = fps + self.vcodec = vcodec + self.pix_fmt = pix_fmt + self.g = g + self.crf = crf + self.preset = preset + self.frame_queue = frame_queue + self.result_queue = result_queue + self.stop_event = stop_event + self.encoder_threads = encoder_threads + + def run(self) -> None: + from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width + + container = None + output_stream = None + stats_tracker = RunningQuantileStats() + frame_count = 0 + + try: + logging.getLogger("libav").setLevel(av.logging.WARNING) + + while True: + try: + frame_data = self.frame_queue.get(timeout=1) + except queue.Empty: + if self.stop_event.is_set(): + break + continue + + if frame_data is None: + # Sentinel: flush and close + break + + # Ensure HWC uint8 numpy array + if isinstance(frame_data, np.ndarray): + if frame_data.ndim == 3 and frame_data.shape[0] == 3: + # CHW -> HWC + frame_data = frame_data.transpose(1, 2, 0) + if frame_data.dtype != np.uint8: + frame_data = (frame_data * 255).astype(np.uint8) + + # Open container on first frame (to get width/height) + if container is None: + height, width = frame_data.shape[:2] + video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset) + if self.encoder_threads is not None: + if self.vcodec == "libsvtav1": + lp_param = f"lp={self.encoder_threads}" + if "svtav1-params" in video_options: + video_options["svtav1-params"] += f":{lp_param}" + else: + video_options["svtav1-params"] = lp_param + else: + video_options["threads"] = str(self.encoder_threads) + Path(self.video_path).parent.mkdir(parents=True, exist_ok=True) + container = 
av.open(str(self.video_path), "w") + output_stream = container.add_stream(self.vcodec, self.fps, options=video_options) + output_stream.pix_fmt = self.pix_fmt + output_stream.width = width + output_stream.height = height + output_stream.time_base = Fraction(1, self.fps) + + # Encode frame with explicit timestamps + pil_img = Image.fromarray(frame_data) + video_frame = av.VideoFrame.from_image(pil_img) + video_frame.pts = frame_count + video_frame.time_base = Fraction(1, self.fps) + packet = output_stream.encode(video_frame) + if packet: + container.mux(packet) + + # Update stats with downsampled frame (per-channel stats like compute_episode_stats) + img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW + img_downsampled = auto_downsample_height_width(img_chw) + # Reshape CHW to (H*W, C) for per-channel stats + channels = img_downsampled.shape[0] + img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels) + stats_tracker.update(img_for_stats) + + frame_count += 1 + + # Flush encoder + if output_stream is not None: + packet = output_stream.encode() + if packet: + container.mux(packet) + + if container is not None: + container.close() + + av.logging.restore_default_callback() + + # Get stats and put on result queue + if frame_count >= 2: + stats = stats_tracker.get_statistics() + self.result_queue.put(("ok", stats)) + else: + self.result_queue.put(("ok", None)) + + except Exception as e: + logging.error(f"Encoder thread error: {e}") + if container is not None: + with contextlib.suppress(Exception): + container.close() + self.result_queue.put(("error", str(e))) + + +class StreamingVideoEncoder: + """Manages per-camera encoder threads for real-time video encoding during recording. + + Instead of writing frames as PNG images and then encoding to MP4 at episode end, + this class streams frames directly to encoder threads, eliminating the + PNG round-trip and making save_episode() near-instant. 
+ + Uses threading instead of multiprocessing to avoid the overhead of pickling large + numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL, + so encoding runs in parallel with the main recording loop. + """ + + def __init__( + self, + fps: int, + vcodec: str = "libsvtav1", + pix_fmt: str = "yuv420p", + g: int | None = 2, + crf: int | None = 30, + preset: int | None = None, + queue_maxsize: int = 30, + encoder_threads: int | None = None, + ): + self.fps = fps + self.vcodec = resolve_vcodec(vcodec) + self.pix_fmt = pix_fmt + self.g = g + self.crf = crf + self.preset = preset + self.queue_maxsize = queue_maxsize + self.encoder_threads = encoder_threads + + self._frame_queues: dict[str, queue.Queue] = {} + self._result_queues: dict[str, queue.Queue] = {} + self._threads: dict[str, _CameraEncoderThread] = {} + self._stop_events: dict[str, threading.Event] = {} + self._video_paths: dict[str, Path] = {} + self._dropped_frames: dict[str, int] = {} + self._episode_active = False + + def start_episode(self, video_keys: list[str], temp_dir: Path) -> None: + """Start encoder threads for a new episode. + + Args: + video_keys: List of video feature keys (e.g. 
["observation.images.laptop"]) + temp_dir: Base directory for temporary MP4 files + """ + if self._episode_active: + self.cancel_episode() + + self._dropped_frames.clear() + + for video_key in video_keys: + frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir)) + video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4" + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=self.fps, + vcodec=self.vcodec, + pix_fmt=self.pix_fmt, + g=self.g, + crf=self.crf, + preset=self.preset, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + encoder_threads=self.encoder_threads, + ) + encoder_thread.start() + + self._frame_queues[video_key] = frame_queue + self._result_queues[video_key] = result_queue + self._threads[video_key] = encoder_thread + self._stop_events[video_key] = stop_event + self._video_paths[video_key] = video_path + + self._episode_active = True + + def feed_frame(self, video_key: str, image: np.ndarray) -> None: + """Feed a frame to the encoder for a specific camera. + + A copy of the image is made before enqueueing to prevent race conditions + with camera drivers that may reuse buffers. If the encoder queue is full + (encoder can't keep up), the frame is dropped with a warning instead of + crashing the recording session. + + Args: + video_key: The video feature key + image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float + + Raises: + RuntimeError: If the encoder thread has crashed + """ + if not self._episode_active: + raise RuntimeError("No active episode. 
Call start_episode() first.") + + thread = self._threads[video_key] + if not thread.is_alive(): + # Check for error + try: + status, msg = self._result_queues[video_key].get_nowait() + if status == "error": + raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}") + except queue.Empty: + pass + raise RuntimeError(f"Encoder thread for {video_key} is not alive") + + try: + self._frame_queues[video_key].put(image.copy(), timeout=0.1) + except queue.Full: + self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1 + count = self._dropped_frames[video_key] + # Log periodically to avoid spam (1st, then every 10th) + if count == 1 or count % 10 == 0: + logging.warning( + f"Encoder queue full for {video_key}, dropped {count} frame(s). " + f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize." + ) + + def finish_episode(self) -> dict[str, tuple[Path, dict | None]]: + """Finish encoding the current episode. + + Sends sentinel values, waits for encoder threads to complete, + and collects results. 
+ + Returns: + Dict mapping video_key to (mp4_path, stats_dict_or_None) + """ + if not self._episode_active: + raise RuntimeError("No active episode to finish.") + + results = {} + + # Report dropped frames + for video_key, count in self._dropped_frames.items(): + if count > 0: + logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.") + + # Send sentinel to all queues + for video_key in self._frame_queues: + self._frame_queues[video_key].put(None) + + # Wait for all threads and collect results + for video_key in self._threads: + self._threads[video_key].join(timeout=120) + if self._threads[video_key].is_alive(): + logging.error(f"Encoder thread for {video_key} did not finish in time") + self._stop_events[video_key].set() + self._threads[video_key].join(timeout=5) + results[video_key] = (self._video_paths[video_key], None) + continue + + try: + status, data = self._result_queues[video_key].get(timeout=5) + if status == "error": + raise RuntimeError(f"Encoder thread for {video_key} failed: {data}") + results[video_key] = (self._video_paths[video_key], data) + except queue.Empty: + logging.error(f"No result from encoder thread for {video_key}") + results[video_key] = (self._video_paths[video_key], None) + + self._cleanup() + self._episode_active = False + return results + + def cancel_episode(self) -> None: + """Cancel the current episode, stopping encoder threads and cleaning up.""" + if not self._episode_active: + return + + # Signal all threads to stop + for video_key in self._stop_events: + self._stop_events[video_key].set() + + # Wait for threads to finish + for video_key in self._threads: + self._threads[video_key].join(timeout=5) + + # Clean up temp MP4 files + video_path = self._video_paths.get(video_key) + if video_path is not None and video_path.exists(): + shutil.rmtree(str(video_path.parent), ignore_errors=True) + + self._cleanup() + self._episode_active = False + + def close(self) -> None: + """Close the encoder, canceling any 
in-progress episode.""" + if self._episode_active: + self.cancel_episode() + + def _cleanup(self) -> None: + """Clean up queues and thread tracking dicts.""" + for q in self._frame_queues.values(): + with contextlib.suppress(Exception): + while not q.empty(): + q.get_nowait() + self._frame_queues.clear() + self._result_queues.clear() + self._threads.clear() + self._stop_events.clear() + self._video_paths.clear() + + @dataclass class VideoFrame: # TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo @@ -514,7 +937,7 @@ with warnings.catch_warnings(): def get_audio_info(video_path: Path | str) -> dict: # Set logging level - logging.getLogger("libav").setLevel(av.logging.ERROR) + logging.getLogger("libav").setLevel(av.logging.WARNING) # Getting audio stream information audio_info = {} @@ -546,7 +969,7 @@ def get_audio_info(video_path: Path | str) -> dict: def get_video_info(video_path: Path | str) -> dict: # Set logging level - logging.getLogger("libav").setLevel(av.logging.ERROR) + logging.getLogger("libav").setLevel(av.logging.WARNING) # Getting video stream information video_info = {} @@ -632,8 +1055,15 @@ class VideoEncodingManager: return self def __exit__(self, exc_type, exc_val, exc_tb): - # Handle any remaining episodes that haven't been batch encoded - if self.dataset.episodes_since_last_encoding > 0: + streaming_encoder = getattr(self.dataset, "_streaming_encoder", None) + + if streaming_encoder is not None: + # Handle streaming encoder cleanup + if exc_type is not None: + streaming_encoder.cancel_episode() + streaming_encoder.close() + elif self.dataset.episodes_since_last_encoding > 0: + # Handle any remaining episodes that haven't been batch encoded if exc_type is not None: logging.info("Exception occurred. 
Encoding remaining episodes before exit...") else: @@ -650,8 +1080,8 @@ class VideoEncodingManager: # Finalize the dataset to properly close all writers self.dataset.finalize() - # Clean up episode images if recording was interrupted - if exc_type is not None: + # Clean up episode images if recording was interrupted (only for non-streaming mode) + if exc_type is not None and streaming_encoder is None: interrupted_episode_index = self.dataset.num_episodes for key in self.dataset.meta.video_keys: img_dir = self.dataset._get_image_file_path( @@ -665,14 +1095,12 @@ class VideoEncodingManager: # Clean up any remaining images directory if it's empty img_dir = self.dataset.root / "images" - # Check for any remaining PNG files - png_files = list(img_dir.rglob("*.png")) - if len(png_files) == 0: - # Only remove the images directory if no PNG files remain - if img_dir.exists(): + if img_dir.exists(): + png_files = list(img_dir.rglob("*.png")) + if len(png_files) == 0: shutil.rmtree(img_dir) logging.debug("Cleaned up empty images directory") - else: - logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") + else: + logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files") return False # Don't suppress the original exception diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 216ab22a6..ec04975d4 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -26,8 +26,10 @@ lerobot-record \ --dataset.repo_id=/ \ --dataset.num_episodes=2 \ --dataset.single_task="Grab the cube" \ + --dataset.streaming_encoding=true \ + --dataset.encoder_threads=2 \ --display_data=true - # <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \ + # <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. 
\ # --dataset.vcodec=h264 \ # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ # --teleop.type=so100_leader \ @@ -58,7 +60,10 @@ lerobot-record \ --display_data=true \ --dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \ --dataset.num_episodes=25 \ - --dataset.single_task="Grab and handover the red cube to the other arm" + --dataset.single_task="Grab and handover the red cube to the other arm" \ + --dataset.streaming_encoding=true \ + # --dataset.vcodec=auto \ + --dataset.encoder_threads=2 ``` """ @@ -179,9 +184,19 @@ class DatasetRecordConfig: # Number of episodes to record before batch encoding videos # Set to 1 for immediate encoding (default behavior), or higher for batched encoding video_encoding_batch_size: int = 1 - # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'. - # Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy. + # Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto', + # or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'. + # Use 'auto' to auto-detect the best available hardware encoder. vcodec: str = "libsvtav1" + # Enable streaming video encoding: encode frames in real-time during capture instead + # of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding + streaming_encoding: bool = False + # Maximum number of frames to buffer per camera when using streaming encoding. + # ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up. + encoder_queue_maxsize: int = 30 + # Number of threads per encoder instance. None = auto (codec default). + # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. 
+ encoder_threads: int | None = None # Rename map for the observation to override the image and state keys rename_map: dict[str, str] = field(default_factory=dict) @@ -452,6 +467,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: root=cfg.dataset.root, batch_encoding_size=cfg.dataset.video_encoding_batch_size, vcodec=cfg.dataset.vcodec, + streaming_encoding=cfg.dataset.streaming_encoding, + encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, + encoder_threads=cfg.dataset.encoder_threads, ) if hasattr(robot, "cameras") and len(robot.cameras) > 0: @@ -474,6 +492,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), batch_encoding_size=cfg.dataset.video_encoding_batch_size, vcodec=cfg.dataset.vcodec, + streaming_encoding=cfg.dataset.streaming_encoding, + encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize, + encoder_threads=cfg.dataset.encoder_threads, ) # Load pretrained policy @@ -497,6 +518,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset: listener, events = init_keyboard_listener() + if not cfg.dataset.streaming_encoding: + logging.info( + "Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. 
More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding" + ) + with VideoEncodingManager(dataset): recorded_episodes = 0 while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index 27c51b3c4..6f99eb301 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -31,7 +31,6 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets.factory import make_dataset from lerobot.datasets.image_writer import image_array_to_pil_image from lerobot.datasets.lerobot_dataset import ( - VALID_VIDEO_CODECS, LeRobotDataset, MultiLeRobotDataset, _encode_video_worker, @@ -45,6 +44,7 @@ from lerobot.datasets.utils import ( hf_transform_to_torch, hw_to_dataset_features, ) +from lerobot.datasets.video_utils import VALID_VIDEO_CODECS from lerobot.envs.factory import make_env_config from lerobot.policies.factory import make_policy_config from lerobot.robots import make_robot_from_config @@ -393,7 +393,7 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory): vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]}, } ds_mixed = empty_lerobot_dataset_factory( - root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2 + root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2, streaming_encoding=False ) ds_mixed.add_frame( { @@ -1450,7 +1450,10 @@ def test_valid_video_codecs_constant(): assert "h264" in VALID_VIDEO_CODECS assert "hevc" in VALID_VIDEO_CODECS assert "libsvtav1" in VALID_VIDEO_CODECS - assert len(VALID_VIDEO_CODECS) == 3 + assert "auto" in VALID_VIDEO_CODECS + assert "h264_videotoolbox" in VALID_VIDEO_CODECS + assert "h264_nvenc" in VALID_VIDEO_CODECS + assert len(VALID_VIDEO_CODECS) == 10 def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory): diff --git 
a/tests/datasets/test_streaming_video_encoder.py b/tests/datasets/test_streaming_video_encoder.py new file mode 100644 index 000000000..a85db6a8d --- /dev/null +++ b/tests/datasets/test_streaming_video_encoder.py @@ -0,0 +1,730 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for streaming video encoding and hardware-accelerated encoding.""" + +import queue +import threading +from unittest.mock import patch + +import av +import numpy as np +import pytest + +from lerobot.datasets.video_utils import ( + VALID_VIDEO_CODECS, + StreamingVideoEncoder, + _CameraEncoderThread, + _get_codec_options, + detect_available_hw_encoders, + resolve_vcodec, +) +from lerobot.utils.constants import OBS_IMAGES + +# ─── _get_codec_options tests ─── + + +class TestGetCodecOptions: + def test_libsvtav1_defaults(self): + opts = _get_codec_options("libsvtav1") + assert opts["g"] == "2" + assert opts["crf"] == "30" + assert opts["preset"] == "12" + + def test_libsvtav1_custom_preset(self): + opts = _get_codec_options("libsvtav1", preset=8) + assert opts["preset"] == "8" + + def test_h264_options(self): + opts = _get_codec_options("h264", g=10, crf=23) + assert opts["g"] == "10" + assert opts["crf"] == "23" + assert "preset" not in opts + + def test_videotoolbox_options(self): + opts = _get_codec_options("h264_videotoolbox", g=2, crf=30) + assert opts["g"] == "2" + # CRF 30 maps to quality = 
max(1, min(100, 100 - 30*2)) = 40 + assert opts["q:v"] == "40" + assert "crf" not in opts + + def test_nvenc_options(self): + opts = _get_codec_options("h264_nvenc", g=2, crf=25) + assert opts["rc"] == "constqp" + assert opts["qp"] == "25" + assert "crf" not in opts + # NVENC doesn't support g + assert "g" not in opts + + def test_vaapi_options(self): + opts = _get_codec_options("h264_vaapi", crf=28) + assert opts["qp"] == "28" + + def test_qsv_options(self): + opts = _get_codec_options("h264_qsv", crf=25) + assert opts["global_quality"] == "25" + + def test_no_g_no_crf(self): + opts = _get_codec_options("h264", g=None, crf=None) + assert "g" not in opts + assert "crf" not in opts + + +# ─── HW encoder detection tests ─── + + +class TestHWEncoderDetection: + def test_detect_available_hw_encoders_returns_list(self): + result = detect_available_hw_encoders() + assert isinstance(result, list) + + def test_detect_available_hw_encoders_only_valid(self): + from lerobot.datasets.video_utils import HW_ENCODERS + + result = detect_available_hw_encoders() + for encoder in result: + assert encoder in HW_ENCODERS + + def test_resolve_vcodec_passthrough(self): + assert resolve_vcodec("libsvtav1") == "libsvtav1" + assert resolve_vcodec("h264") == "h264" + + def test_resolve_vcodec_auto_fallback(self): + """When no HW encoders are available, auto should fall back to libsvtav1.""" + with patch("lerobot.datasets.video_utils.detect_available_hw_encoders", return_value=[]): + assert resolve_vcodec("auto") == "libsvtav1" + + def test_resolve_vcodec_auto_picks_hw(self): + """When a HW encoder is available, auto should pick it.""" + with patch( + "lerobot.datasets.video_utils.detect_available_hw_encoders", + return_value=["h264_videotoolbox"], + ): + assert resolve_vcodec("auto") == "h264_videotoolbox" + + def test_resolve_vcodec_auto_returns_valid(self): + """Test that resolve_vcodec('auto') returns a known valid codec.""" + result = resolve_vcodec("auto") + assert result in 
VALID_VIDEO_CODECS + + def test_hw_encoder_names_accepted_in_validation(self): + """Test that HW encoder names pass validation in VALID_VIDEO_CODECS.""" + assert "auto" in VALID_VIDEO_CODECS + assert "h264_videotoolbox" in VALID_VIDEO_CODECS + assert "h264_nvenc" in VALID_VIDEO_CODECS + + def test_resolve_vcodec_invalid_raises(self): + """Test that resolve_vcodec raises ValueError for invalid codecs.""" + with pytest.raises(ValueError, match="Invalid vcodec"): + resolve_vcodec("not_a_real_codec") + + +# ─── _CameraEncoderThread tests ─── + + +class TestCameraEncoderThread: + def test_encodes_valid_mp4(self, tmp_path): + """Test that the encoder thread creates a valid MP4 file with correct frame count.""" + num_frames = 30 + height, width = 64, 96 + fps = 30 + video_path = tmp_path / "test_output" / "test.mp4" + + frame_queue: queue.Queue = queue.Queue(maxsize=60) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=fps, + vcodec="libsvtav1", + pix_fmt="yuv420p", + g=2, + crf=30, + preset=13, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + ) + encoder_thread.start() + + # Feed frames (HWC uint8) + for _ in range(num_frames): + frame = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8) + frame_queue.put(frame) + + # Send sentinel + frame_queue.put(None) + encoder_thread.join(timeout=60) + assert not encoder_thread.is_alive() + + # Check result + status, data = result_queue.get(timeout=5) + assert status == "ok" + assert data is not None # Stats should be returned + assert "mean" in data + assert "std" in data + assert "min" in data + assert "max" in data + assert "count" in data + + # Verify the MP4 file is valid + assert video_path.exists() + with av.open(str(video_path)) as container: + stream = container.streams.video[0] + # The frame count should match + total_frames = sum(1 for _ in container.decode(stream)) + 
assert total_frames == num_frames + + def test_handles_chw_input(self, tmp_path): + """Test that CHW format input is handled correctly.""" + num_frames = 5 + fps = 30 + video_path = tmp_path / "test_chw" / "test.mp4" + + frame_queue: queue.Queue = queue.Queue(maxsize=60) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=fps, + vcodec="libsvtav1", + pix_fmt="yuv420p", + g=2, + crf=30, + preset=13, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + ) + encoder_thread.start() + + # Feed CHW frames + for _ in range(num_frames): + frame = np.random.randint(0, 255, (3, 64, 96), dtype=np.uint8) + frame_queue.put(frame) + + frame_queue.put(None) + encoder_thread.join(timeout=60) + + status, _ = result_queue.get(timeout=5) + assert status == "ok" + assert video_path.exists() + + def test_stop_event_cancellation(self, tmp_path): + """Test that setting the stop event causes the thread to exit.""" + fps = 30 + video_path = tmp_path / "test_cancel" / "test.mp4" + + frame_queue: queue.Queue = queue.Queue(maxsize=60) + result_queue: queue.Queue = queue.Queue(maxsize=1) + stop_event = threading.Event() + + encoder_thread = _CameraEncoderThread( + video_path=video_path, + fps=fps, + vcodec="libsvtav1", + pix_fmt="yuv420p", + g=2, + crf=30, + preset=13, + frame_queue=frame_queue, + result_queue=result_queue, + stop_event=stop_event, + ) + encoder_thread.start() + + # Feed a few frames + for _ in range(3): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + frame_queue.put(frame) + + # Signal stop instead of sending sentinel + stop_event.set() + encoder_thread.join(timeout=10) + assert not encoder_thread.is_alive() + + +# ─── StreamingVideoEncoder tests ─── + + +class TestStreamingVideoEncoder: + def test_single_camera_episode(self, tmp_path): + """Test encoding a single camera episode.""" + encoder = 
StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13) + + video_keys = [f"{OBS_IMAGES}.laptop"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 20 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.laptop", frame) + + results = encoder.finish_episode() + assert f"{OBS_IMAGES}.laptop" in results + + mp4_path, stats = results[f"{OBS_IMAGES}.laptop"] + assert mp4_path.exists() + assert stats is not None + + # Verify frame count + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames + + encoder.close() + + def test_multi_camera_episode(self, tmp_path): + """Test encoding multiple cameras simultaneously.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + + video_keys = [f"{OBS_IMAGES}.laptop", f"{OBS_IMAGES}.phone"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 15 + for _ in range(num_frames): + frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(video_keys[0], frame0) + encoder.feed_frame(video_keys[1], frame1) + + results = encoder.finish_episode() + + for key in video_keys: + assert key in results + mp4_path, stats = results[key] + assert mp4_path.exists() + assert stats is not None + + encoder.close() + + def test_sequential_episodes(self, tmp_path): + """Test that multiple sequential episodes work correctly.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + video_keys = [f"{OBS_IMAGES}.cam"] + + for ep in range(3): + encoder.start_episode(video_keys, tmp_path) + num_frames = 10 + ep * 5 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + 
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + results = encoder.finish_episode() + + mp4_path, stats = results[f"{OBS_IMAGES}.cam"] + assert mp4_path.exists() + + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames + + encoder.close() + + def test_cancel_episode(self, tmp_path): + """Test that canceling an episode cleans up properly.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + video_keys = [f"{OBS_IMAGES}.cam"] + + encoder.start_episode(video_keys, tmp_path) + + for _ in range(5): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + encoder.cancel_episode() + + # Should be able to start a new episode after cancel + encoder.start_episode(video_keys, tmp_path) + for _ in range(5): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + results = encoder.finish_episode() + + assert f"{OBS_IMAGES}.cam" in results + encoder.close() + + def test_feed_without_start_raises(self, tmp_path): + """Test that feeding frames without starting an episode raises.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + with pytest.raises(RuntimeError, match="No active episode"): + encoder.feed_frame("cam", np.zeros((64, 96, 3), dtype=np.uint8)) + encoder.close() + + def test_finish_without_start_raises(self, tmp_path): + """Test that finishing without starting raises.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + with pytest.raises(RuntimeError, match="No active episode"): + encoder.finish_episode() + encoder.close() + + def test_close_is_idempotent(self, tmp_path): + """Test that close() can be called multiple times safely.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + 
encoder.close() + encoder.close() # Should not raise + + def test_video_duration_matches_frame_count(self, tmp_path): + """Test that encoded video duration matches num_frames / fps.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13) + video_keys = [f"{OBS_IMAGES}.cam"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 90 # 3 seconds at 30fps + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + results = encoder.finish_episode() + mp4_path, _ = results[f"{OBS_IMAGES}.cam"] + + expected_duration = num_frames / 30.0 # 3.0 seconds + + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + if stream.duration is not None: + actual_duration = float(stream.duration * stream.time_base) + else: + actual_duration = float(container.duration / av.time_base) + + assert total_frames == num_frames + # Allow small tolerance for duration due to codec framing + assert abs(actual_duration - expected_duration) < 0.5, ( + f"Video duration {actual_duration:.2f}s != expected {expected_duration:.2f}s" + ) + + encoder.close() + + def test_multi_camera_start_episode_called_once(self, tmp_path): + """Test that with multiple cameras, no frames are lost due to double start_episode.""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30) + + video_keys = [f"{OBS_IMAGES}.cam1", f"{OBS_IMAGES}.cam2"] + encoder.start_episode(video_keys, tmp_path) + + num_frames = 30 + for _ in range(num_frames): + frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(video_keys[0], frame0) + encoder.feed_frame(video_keys[1], frame1) + + results = encoder.finish_episode() + + # Both cameras should have all frames + for key in 
video_keys: + mp4_path, stats = results[key] + assert mp4_path.exists() + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames, ( + f"Camera {key}: expected {num_frames} frames, got {total_frames}" + ) + + encoder.close() + + def test_encoder_threads_passed_to_thread(self, tmp_path): + """Test that encoder_threads is stored and passed through to encoder threads.""" + encoder = StreamingVideoEncoder( + fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, encoder_threads=2 + ) + assert encoder.encoder_threads == 2 + + video_keys = [f"{OBS_IMAGES}.cam"] + encoder.start_episode(video_keys, tmp_path) + + # Verify the thread received the encoder_threads value + thread = encoder._threads[f"{OBS_IMAGES}.cam"] + assert thread.encoder_threads == 2 + + # Feed some frames and finish to ensure it works end-to-end + num_frames = 10 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + results = encoder.finish_episode() + mp4_path, stats = results[f"{OBS_IMAGES}.cam"] + assert mp4_path.exists() + assert stats is not None + + with av.open(str(mp4_path)) as container: + stream = container.streams.video[0] + total_frames = sum(1 for _ in container.decode(stream)) + assert total_frames == num_frames + + encoder.close() + + def test_encoder_threads_none_by_default(self, tmp_path): + """Test that encoder_threads defaults to None (codec auto-detect).""" + encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p") + assert encoder.encoder_threads is None + encoder.close() + + def test_graceful_frame_dropping(self, tmp_path): + """Test that full queue drops frames instead of crashing.""" + encoder = StreamingVideoEncoder( + fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13, queue_maxsize=1 + ) + video_keys = [f"{OBS_IMAGES}.cam"] + 
encoder.start_episode(video_keys, tmp_path) + + # Feed many frames quickly - with queue_maxsize=1, some will be dropped + num_frames = 50 + for _ in range(num_frames): + frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8) + encoder.feed_frame(f"{OBS_IMAGES}.cam", frame) + + # Should not raise - frames are dropped gracefully + results = encoder.finish_episode() + assert f"{OBS_IMAGES}.cam" in results + + mp4_path, _ = results[f"{OBS_IMAGES}.cam"] + assert mp4_path.exists() + + # Some frames should have been dropped (queue was tiny) + dropped = encoder._dropped_frames.get(f"{OBS_IMAGES}.cam", 0) + # We can't guarantee drops but can verify no crash occurred + assert dropped >= 0 + + encoder.close() + + +# ─── Integration tests with LeRobotDataset ─── + + +class TestStreamingEncoderIntegration: + def test_add_frame_save_episode_streaming(self, tmp_path): + """Full integration test: add_frame -> save_episode with streaming encoding.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/streaming", + fps=30, + features=features, + root=tmp_path / "streaming_test", + use_videos=True, + streaming_encoding=True, + ) + + assert dataset._streaming_encoder is not None + + num_frames = 20 + for _ in range(num_frames): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(6).astype(np.float32), + "task": "test task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + # Verify dataset metadata + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == num_frames + + # Verify stats exist for the video key + assert dataset.meta.stats is not None + assert 
"observation.images.cam" in dataset.meta.stats + assert "action" in dataset.meta.stats + + dataset.finalize() + + def test_streaming_disabled_creates_pngs(self, tmp_path): + """Test that disabling streaming encoding falls back to PNG path.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/no_streaming", + fps=30, + features=features, + root=tmp_path / "no_streaming_test", + use_videos=True, + streaming_encoding=False, + ) + + assert dataset._streaming_encoder is None + + num_frames = 5 + for _ in range(num_frames): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(6).astype(np.float32), + "task": "test task", + } + dataset.add_frame(frame) + + # With streaming disabled, PNG files should be written + images_dir = dataset.root / "images" + assert images_dir.exists() + + dataset.save_episode() + dataset.finalize() + + def test_multi_episode_streaming(self, tmp_path): + """Test recording multiple episodes with streaming encoding.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/multi_ep", + fps=30, + features=features, + root=tmp_path / "multi_ep_test", + use_videos=True, + streaming_encoding=True, + ) + + for ep in range(3): + num_frames = 10 + ep * 5 + for _ in range(num_frames): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": 
np.random.randn(2).astype(np.float32), + "task": f"task_{ep}", + } + dataset.add_frame(frame) + dataset.save_episode() + + assert dataset.meta.total_episodes == 3 + assert dataset.meta.total_frames == 10 + 15 + 20 + + dataset.finalize() + + def test_clear_episode_buffer_cancels_streaming(self, tmp_path): + """Test that clearing episode buffer cancels streaming encoding.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/cancel", + fps=30, + features=features, + root=tmp_path / "cancel_test", + use_videos=True, + streaming_encoding=True, + ) + + # Add some frames + for _ in range(5): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(2).astype(np.float32), + "task": "task", + } + dataset.add_frame(frame) + + # Cancel and re-record + dataset.clear_episode_buffer() + + # Record a new episode + for _ in range(10): + frame = { + "observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(2).astype(np.float32), + "task": "task", + } + dataset.add_frame(frame) + dataset.save_episode() + + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == 10 + + dataset.finalize() + + def test_multi_camera_add_frame_streaming(self, tmp_path): + """Test that start_episode is called once with multiple video keys.""" + from lerobot.datasets.lerobot_dataset import LeRobotDataset + + features = { + "observation.images.cam1": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "observation.images.cam2": { + "dtype": "video", + "shape": (64, 96, 3), + "names": ["height", "width", "channels"], + }, + "action": {"dtype": 
"float32", "shape": (2,), "names": ["j1", "j2"]}, + } + + dataset = LeRobotDataset.create( + repo_id="test/multi_cam", + fps=30, + features=features, + root=tmp_path / "multi_cam_test", + use_videos=True, + streaming_encoding=True, + ) + + num_frames = 15 + for _ in range(num_frames): + frame = { + "observation.images.cam1": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "observation.images.cam2": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8), + "action": np.random.randn(2).astype(np.float32), + "task": "test task", + } + dataset.add_frame(frame) + + dataset.save_episode() + + assert dataset.meta.total_episodes == 1 + assert dataset.meta.total_frames == num_frames + + dataset.finalize() From a0c5d193919cccc1125c7d03f4af94013d400b9d Mon Sep 17 00:00:00 2001 From: Yueci Deng Date: Mon, 23 Feb 2026 23:32:59 +0800 Subject: [PATCH 106/111] add metadata_buffer_size to dataset creation (#2998) Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- src/lerobot/datasets/lerobot_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index 65b475e26..83d452a44 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -1653,6 +1653,7 @@ class LeRobotDataset(torch.utils.data.Dataset): video_backend: str | None = None, batch_encoding_size: int = 1, vcodec: str = "libsvtav1", + metadata_buffer_size: int = 10, streaming_encoding: bool = False, encoder_queue_maxsize: int = 30, encoder_threads: int | None = None, @@ -1667,6 +1668,7 @@ class LeRobotDataset(torch.utils.data.Dataset): features=features, root=root, use_videos=use_videos, + metadata_buffer_size=metadata_buffer_size, ) obj.repo_id = obj.meta.repo_id obj.root = obj.meta.root From 544cbc5f3874d20f7d0f388ef318dc4838c0906f Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Mon, 23 Feb 2026 16:39:04 +0100 Subject: [PATCH 107/111] feat(motors): add RobStride CAN 
implementation (#2821) * feat(motors): add initial implementation of robstride Co-authored-by: Virgile * chore(motors): solve some linter * remove kp/kd attribute * code uniformisation between damiao and robstride * remove normalization warning * remove non valid baudrates and small docstring update * remove all useless files. Only keeping robstride.py and table.py * typing for mypy * reduce NameOrId usage * align signature with damiao * put the same helper than in the damiao implementation * bug correction : expect a response after each bus.send --------- Co-authored-by: Virgile --- pyproject.toml | 4 +- src/lerobot/motors/robstride/__init__.py | 18 + src/lerobot/motors/robstride/robstride.py | 1003 +++++++++++++++++++++ src/lerobot/motors/robstride/tables.py | 120 +++ src/lerobot/scripts/lerobot_setup_can.py | 1 + 5 files changed, 1145 insertions(+), 1 deletion(-) create mode 100644 src/lerobot/motors/robstride/__init__.py create mode 100644 src/lerobot/motors/robstride/robstride.py create mode 100644 src/lerobot/motors/robstride/tables.py diff --git a/pyproject.toml b/pyproject.toml index 0ca1f0432..ea3df4a6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,11 +98,13 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"] placo-dep = ["placo>=0.9.6,<0.10.0"] transformers-dep = ["transformers>=4.57.1,<5.0.0"] grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"] +can-dep = ["python-can>=4.2.0,<5.0.0"] # Motors feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"] dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"] -damiao = ["python-can>=4.2.0,<5.0.0"] +damiao = ["lerobot[can-dep]"] +robstride = ["lerobot[can-dep]"] # Robots openarms = ["lerobot[damiao]"] diff --git a/src/lerobot/motors/robstride/__init__.py b/src/lerobot/motors/robstride/__init__.py new file mode 100644 index 000000000..7933ac6fa --- /dev/null +++ b/src/lerobot/motors/robstride/__init__.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python + +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .robstride import RobstrideMotorsBus +from .tables import * diff --git a/src/lerobot/motors/robstride/robstride.py b/src/lerobot/motors/robstride/robstride.py new file mode 100644 index 000000000..f47e41509 --- /dev/null +++ b/src/lerobot/motors/robstride/robstride.py @@ -0,0 +1,1003 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# TODO(Virgile) : Robustify mode control , only the MIT protocole is implemented for now + +import logging +import time +from contextlib import contextmanager +from copy import deepcopy +from functools import cached_property +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any, TypedDict + +from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected +from lerobot.utils.import_utils import _can_available + +if TYPE_CHECKING or _can_available: + import can +else: + can = SimpleNamespace(Message=object, interface=None) +import numpy as np + +from lerobot.utils.errors import DeviceNotConnectedError +from lerobot.utils.utils import enter_pressed, move_cursor_up + +from ..motors_bus import Motor, MotorCalibration, MotorsBusBase, NameOrID, Value +from .tables import ( + AVAILABLE_BAUDRATES, + CAN_CMD_CLEAR_FAULT, + CAN_CMD_DISABLE, + CAN_CMD_ENABLE, + CAN_CMD_SET_ZERO, + DEFAULT_BAUDRATE, + DEFAULT_TIMEOUT_MS, + MODEL_RESOLUTION, + MOTOR_LIMIT_PARAMS, + NORMALIZED_DATA, + PARAM_TIMEOUT, + RUNNING_TIMEOUT, + STATE_CACHE_TTL_S, + ControlMode, + MotorType, +) + +logger = logging.getLogger(__name__) + + +class MotorState(TypedDict): + position: float + velocity: float + torque: float + temp_mos: float + temp_rotor: float + + +class RobstrideMotorsBus(MotorsBusBase): + """ + The Robstride implementation for a MotorsBus using CAN bus communication. + + This class uses python-can for CAN bus communication with Robstride motors. + The motors need to be switched to MIT control mode to be compatible with this implementation. 
+ More details on the protocol can be found in the documentation links below: + - python-can documentation: https://python-can.readthedocs.io/en/stable/ + - Robstride CAN protocol: https://github.com/RobStride/MotorStudio + """ + + # CAN-specific settings + available_baudrates = deepcopy(AVAILABLE_BAUDRATES) + default_baudrate = DEFAULT_BAUDRATE + default_timeout = DEFAULT_TIMEOUT_MS + + # Motor configuration + model_resolution_table = deepcopy(MODEL_RESOLUTION) + normalized_data = deepcopy(NORMALIZED_DATA) + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + can_interface: str = "auto", + use_can_fd: bool = True, + bitrate: int = 1000000, + data_bitrate: int | None = 5000000, + ): + """ + Initialize the Robstride motors bus. + + Args: + port: CAN interface name (e.g., "can0" for Linux, "/dev/cu.usbmodem*" for macOS) + motors: Dictionary mapping motor names to Motor objects + calibration: Optional calibration data + can_interface: CAN interface type - "auto" (default), "socketcan" (Linux), or "slcan" (macOS/serial) + use_can_fd: Whether to use CAN FD mode (default: True for OpenArms) + bitrate: Nominal bitrate in bps (default: 1000000 = 1 Mbps) + data_bitrate: Data bitrate for CAN FD in bps (default: 5000000 = 5 Mbps), ignored if use_can_fd is False + """ + super().__init__(port, motors, calibration) + self.port = port + self.can_interface = can_interface + self.use_can_fd = use_can_fd + self.bitrate = bitrate + self.data_bitrate = data_bitrate + self.canbus: can.BusABC | None = None + self._is_connected = False + + # Map motor names to CAN IDs + self._motor_can_ids: dict[str, int] = {} + self._recv_id_to_motor: dict[int, str] = {} + + # Store motor types and recv IDs + self._motor_types: dict[str, MotorType] = {} + # Dynamic gains storage (Damiao-style update path via write/sync_write) + self._gains: dict[str, dict[str, float]] = {} + for name, motor in self.motors.items(): + if 
motor.motor_type_str is not None: + self._motor_types[name] = getattr(MotorType, motor.motor_type_str.upper()) + else: + # Default to O0if not specified + self._motor_types[name] = MotorType.O0 + + # Damiao-style defaults: fixed gains at startup for every motor. + self._gains[name] = {"kp": 10.0, "kd": 0.5} + + # Map recv_id to motor name for filtering responses + if motor.recv_id is not None: + self._recv_id_to_motor[motor.recv_id] = name + # Motor Mode + self.enabled: dict[str, bool] = {} + self.operation_mode: dict[str, ControlMode] = {} + self._last_known_states: dict[str, MotorState] = { + name: { + "position": 0.0, + "velocity": 0.0, + "torque": 0.0, + "temp_mos": 0.0, + "temp_rotor": 0.0, + } + for name in self.motors + } + self.last_feedback_time: dict[str, float | None] = {} + self._id_to_name: dict[int, str] = {} + for name in self.motors: + self.enabled[name] = False + self.operation_mode[name] = ControlMode.MIT # default mode + self.last_feedback_time[name] = None + + for name, motor in self.motors.items(): + key = motor.recv_id if motor.recv_id is not None else motor.id + self._id_to_name[key] = name + + @property + def is_connected(self) -> bool: + """Check if the CAN bus is connected.""" + return self._is_connected and self.canbus is not None + + def _bus(self) -> can.BusABC: + if self.canbus is None: + raise DeviceNotConnectedError(f"{self.__class__.__name__}('{self.port}') is not connected.") + return self.canbus + + @check_if_already_connected + def connect(self, handshake: bool = True) -> None: + """ + Open the CAN bus and initialize communication. 
+ + Args: + handshake: If True, ping all motors to verify they're present + """ + try: + # Auto-detect interface type based on port name + if self.can_interface == "auto": + if self.port.startswith("/dev/"): + self.can_interface = "slcan" + logger.info(f"Auto-detected slcan interface for port {self.port}") + else: + self.can_interface = "socketcan" + logger.info(f"Auto-detected socketcan interface for port {self.port}") + + kwargs = { + "channel": self.port, + "bitrate": self.bitrate, + "interface": self.can_interface, + } + + if self.can_interface == "socketcan" and self.use_can_fd and self.data_bitrate is not None: + kwargs.update({"data_bitrate": self.data_bitrate, "fd": True}) + logger.info( + f"Connected to {self.port} with CAN FD (bitrate={self.bitrate}, data_bitrate={self.data_bitrate})" + ) + else: + logger.info(f"Connected to {self.port} with {self.can_interface} (bitrate={self.bitrate})") + + self.canbus = can.interface.Bus(**kwargs) + + self._is_connected = True + + if handshake: + self._handshake() + + logger.debug(f"{self.__class__.__name__} connected via {self.can_interface}.") + except Exception as e: + self._is_connected = False + raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e + + def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]: + motor_name = self._get_motor_name(motor) + motor_id = self._get_motor_id(motor_name) + recv_id = self._get_motor_recv_id(motor_name) + data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + return self._recv_status_via_clear_fault(expected_recv_id=recv_id) + + def _recv_status_via_clear_fault( + self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT + ) -> tuple[bool, can.Message | None]: + """ + Poll the bus for a response to a fault-clear request. + + Args: + expected_recv_id: Only accept frames from this CAN ID when provided. 
+ timeout: Maximum time spent polling the bus in seconds. + + Returns: + Tuple where the first element is True if a fault frame was received, + and the second element is the CAN message (or None on timeout). + """ + start_time = time.time() + + while time.time() - start_time < timeout: + msg = self._bus().recv(timeout=RUNNING_TIMEOUT / 10) + if not msg: + continue + + if expected_recv_id is not None and msg.data[0] != expected_recv_id: + continue + + # Fault-status frame heuristic (doc-based) + fault_bits = int.from_bytes(msg.data[1:5], "little") + if fault_bits != 0 and msg.data[5] == msg.data[6] == msg.data[7] == 0: + logger.error( + f"Motor fault received from CAN ID 0x{msg.arbitration_id:02X}: " + f"fault_bits=0x{fault_bits:08X}" + ) + return True, msg + + # Otherwise: valid normal response + return False, msg + + return False, None + + def update_motor_state(self, motor: NameOrID) -> bool: + has_fault, msg = self._query_status_via_clear_fault(motor) + if msg is None: + logger.warning(f"No response received from motor '{motor}' during state update.") + raise ConnectionError(f"No response received from motor '{motor}' during state update.") + if has_fault: + logger.error(f"Fault reported by motor '{motor}' during state update. msg={msg.data.hex()}") + raise RuntimeError(f"Fault reported by motor '{motor}' during state update.") + + self._decode_motor_state(msg.data) # updates cache + return True + + def _handshake(self) -> None: + logger.info("Starting handshake with motors...") + missing_motors = [] + faulted_motors = [] + + for motor_name in self.motors: + has_fault, msg = self._query_status_via_clear_fault(motor_name) + if msg is None: + missing_motors.append(motor_name) + elif has_fault: + faulted_motors.append(motor_name) + else: + # CLEAR_FAULT responses are not guaranteed to always match the MIT feedback layout + # on all firmware versions. Handshake should not fail just because cache warm-up fails. 
+ try: + self._decode_motor_state(msg.data) + except Exception as e: + logger.debug( + "Handshake cache warm-up decode failed for motor '%s': %s", + motor_name, + e, + ) + time.sleep(0.01) + + if missing_motors or faulted_motors: + details = [] + if missing_motors: + details.append(f"did not respond: {missing_motors}") + if faulted_motors: + details.append(f"reported fault: {faulted_motors}") + raise ConnectionError("Handshake failed. " + "; ".join(details)) + + logger.info("Handshake successful. All motors ready.") + + def _switch_operation_mode(self, motor: NameOrID, mode: ControlMode) -> None: + """Switch the operation mode of a motor.""" + motor_name = self._get_motor_name(motor) + motor_id = self._get_motor_id(motor_name) + recv_id = self._get_motor_recv_id(motor_name) + data = [0xFF] * 8 + data[6] = mode.value + data[7] = 0xFC + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + msg = self._recv_motor_response(expected_recv_id=recv_id, timeout=PARAM_TIMEOUT) + if msg is not None: + self.operation_mode[motor_name] = mode + + @check_if_not_connected + def disconnect(self, disable_torque: bool = True) -> None: + """ + Close the CAN bus connection. 
+ + Args: + disable_torque: If True, disable torque on all motors before disconnecting + """ + if disable_torque: + try: + self.disable_torque() + except Exception as e: + logger.warning(f"Failed to disable torque during disconnect: {e}") + + if self.canbus: + self.canbus.shutdown() + self.canbus = None + self._is_connected = False + logger.debug(f"{self.__class__.__name__} disconnected.") + + def configure_motors(self) -> None: + """Configure all motors with default settings.""" + # Robstride motors don't require much configuration in MIT mode + # Just ensure they're enabled + for motor in self.motors: + self._enable_motor(self._get_motor_name(motor)) + self._switch_operation_mode(motor, ControlMode.MIT) + time.sleep(0.01) + + def switch_to_mode(self, mode: ControlMode) -> None: + """Switch operation mode on selected motors.""" + for motor in self.motors: + self._switch_operation_mode(motor, mode) + time.sleep(0.01) + + def _enable_motor(self, motor: NameOrID) -> None: + """Enable a single motor.""" + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [CAN_CMD_ENABLE] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + self._recv_motor_response(expected_recv_id=recv_id, timeout=PARAM_TIMEOUT) + + def _disable_motor(self, motor: NameOrID) -> None: + """Disable a single motor.""" + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [CAN_CMD_DISABLE] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + self._recv_motor_response(expected_recv_id=recv_id) + + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors.""" + motors = self._get_motors_list(motors) + for motor in motors: + for _ in range(num_retry + 1): + try: + self._get_motor_name(motor) + 
self._enable_motor(self._get_motor_name(motor)) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(0.01) + + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors.""" + motors = self._get_motors_list(motors) + for motor in motors: + for _ in range(num_retry + 1): + try: + self._disable_motor(self._get_motor_name(motor)) + break + except Exception as e: + if _ == num_retry: + raise e + time.sleep(0.01) + + @contextmanager + def torque_disabled(self, motors: str | list[str] | None = None): + """ + Context manager that guarantees torque is re-enabled. + + This helper is useful to temporarily disable torque when configuring motors. + + Examples: + >>> with bus.torque_disabled(): + ... # Safe operations here with torque disabled + ... pass + """ + self.disable_torque(motors) + try: + yield + finally: + self.enable_torque(motors) + + def set_zero_position(self, motors: str | list[str] | None = None) -> None: + """Set current position as zero for selected motors.""" + motors = self._get_motors_list(motors) + for motor in motors: + motor_id = self._get_motor_id(motor) + recv_id = self._get_motor_recv_id(motor) + data = [0xFF] * 7 + [CAN_CMD_SET_ZERO] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + self._recv_motor_response(expected_recv_id=recv_id) + time.sleep(0.01) + + def _recv_motor_response( + self, expected_recv_id: int | None = None, timeout: float = 0.001 + ) -> can.Message | None: + """ + Receive a response from a motor. 
+ + Args: + expected_recv_id: If provided, only return messages from this CAN ID + timeout: Timeout in seconds (default: 1ms for high-speed operation) + + Returns: + CAN message if received, None otherwise + """ + try: + start_time = time.time() + messages_seen = [] + while time.time() - start_time < timeout: + msg = self._bus().recv(timeout=RUNNING_TIMEOUT / 10) # 100us timeout for fast polling + if msg: + messages_seen.append(f"0x{msg.arbitration_id:02X}") + # If no filter specified, return any message + if expected_recv_id is None: + return msg + # Otherwise, only return if it matches the expected recv_id + if msg.data[0] == expected_recv_id: + return msg + else: + logger.debug( + f"Ignoring message from CAN ID 0x{msg.arbitration_id:02X}, expected 0x{expected_recv_id:02X}" + ) + + # Only log warnings if we're in debug mode to reduce overhead + if logger.isEnabledFor(logging.DEBUG): + if messages_seen: + logger.debug( + f"Received {len(messages_seen)} message(s) from IDs {set(messages_seen)}, but expected 0x{expected_recv_id:02X}" + ) + else: + logger.debug(f"No CAN messages received (expected from 0x{expected_recv_id:02X})") + except Exception as e: + logger.debug(f"Failed to receive CAN message: {e}") + return None + + def _recv_all_responses( + self, expected_recv_ids: list[int], timeout: float = 0.002 + ) -> dict[int, can.Message]: + """ + Efficiently receive responses from multiple motors at once. + Uses the OpenArms pattern: collect all available messages within timeout. 
+ + Args: + expected_recv_ids: List of CAN IDs we expect responses from + timeout: Total timeout in seconds (default: 2ms) + + Returns: + Dictionary mapping recv_id to CAN message + """ + responses: dict[int, can.Message] = {} + expected_set = set(expected_recv_ids) + start_time = time.time() + + try: + while len(responses) < len(expected_recv_ids) and (time.time() - start_time) < timeout: + msg = self._bus().recv(timeout=RUNNING_TIMEOUT / 10) # 100us poll timeout + if msg and msg.data[0] in expected_set: + responses[msg.data[0]] = msg + if len(responses) == len(expected_recv_ids): + break # Got all responses, exit early + except Exception as e: + logger.debug(f"Error receiving responses: {e}") + + return responses + + def _speed_control( + self, + motor: NameOrID, + velocity_deg_per_sec: float, + current_limit_a: float, + ) -> None: + """ + Send a Velocity Mode Control Command (Command 11) to a single motor. + + Args: + motor: Motor name or CAN ID. + velocity_rad_per_sec: Target speed in rad/s (32-bit float). + current_limit_a: Current limit in A (32-bit float). 
+ """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + motor_id = self._get_motor_id(motor) + motor_name = self._get_motor_name(motor) + # Optional: ensure the motor is in velocity control mode + + if self.operation_mode[motor_name] != ControlMode.VEL: + raise RuntimeError(f"Motor '{motor_name}' is not in velocity control mode.") + # Convert to rad/s to match protocol specification + + velocity_rad_per_sec = np.radians(velocity_deg_per_sec) + + # Encode float32 little-endian without struct (byte list) + def _float32_to_le_bytes(x: float) -> list[int]: + b = np.float32(x).tobytes() # 4 bytes, little-endian + return [b[0], b[1], b[2], b[3]] + + speed_bytes = _float32_to_le_bytes(velocity_rad_per_sec) + limit_bytes = _float32_to_le_bytes(current_limit_a) + + data = speed_bytes + limit_bytes # 8 octets : [0–3]=speed, [4–7]=current limit + + msg = can.Message( + arbitration_id=motor_id, + data=data, + is_extended_id=False, + ) + self._bus().send(msg) + + # Si le proto renvoie une réponse type état, on peut la décoder comme pour MIT + recv_id = self._get_motor_recv_id(motor) + if recv_id is not None: + resp = self._recv_motor_response(expected_recv_id=recv_id) + if resp: + self._decode_motor_state(resp.data) + + def _mit_control( + self, + motor: NameOrID, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + *, + wait_for_response: bool = True, + ) -> None: + """ + Send MIT control command to a motor. 
+ + Args: + motor: Motor name or ID + kp: Position gain + kd: Velocity gain + position_degrees: Target position (degrees) + velocity_deg_per_sec: Target velocity (degrees/s) + torque: Target torque (N·m) + """ + motor_name = self._get_motor_name(motor) + motor_type = self._motor_types[motor_name] + if self.operation_mode[motor_name] != ControlMode.MIT: + raise RuntimeError(f"Motor '{motor_name}' is not in MIT control mode.") + motor_id = self._get_motor_id(motor) + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + + if wait_for_response: + recv_id = self._get_motor_recv_id(motor) + msg = self._recv_motor_response(expected_recv_id=recv_id) + if msg: + self._process_response(motor_name, msg) + + def _encode_mit_packet( + self, + motor_type: MotorType, + kp: float, + kd: float, + position_degrees: float, + velocity_deg_per_sec: float, + torque: float, + ) -> list[int]: + """Encode an MIT control command payload from physical units.""" + position_rad = np.radians(position_degrees) + velocity_rad_per_sec = np.radians(velocity_deg_per_sec) + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + kp_uint = self._float_to_uint(kp, 0, 500, 12) + kd_uint = self._float_to_uint(kd, 0, 5, 12) + q_uint = self._float_to_uint(position_rad, -pmax, pmax, 16) + dq_uint = self._float_to_uint(velocity_rad_per_sec, -vmax, vmax, 12) + tau_uint = self._float_to_uint(torque, -tmax, tmax, 12) + + data = [0] * 8 + data[0] = (q_uint >> 8) & 0xFF + data[1] = q_uint & 0xFF + data[2] = dq_uint >> 4 + data[3] = ((dq_uint & 0xF) << 4) | ((kp_uint >> 8) & 0xF) + data[4] = kp_uint & 0xFF + data[5] = kd_uint >> 4 + data[6] = ((kd_uint & 0xF) << 4) | ((tau_uint >> 8) & 0xF) + data[7] = tau_uint & 0xFF + return data + + def _mit_control_batch( + self, + commands: dict[NameOrID, tuple[float, float, float, float, float]], + ) -> None: + """Send MIT 
commands in batch and update cache from collected responses.""" + if not commands: + return + + recv_id_to_motor: dict[int, str] = {} + for motor, (kp, kd, position_degrees, velocity_deg_per_sec, torque) in commands.items(): + motor_name = self._get_motor_name(motor) + if self.operation_mode[motor_name] != ControlMode.MIT: + raise RuntimeError(f"Motor '{motor_name}' is not in MIT control mode.") + + motor_id = self._get_motor_id(motor) + motor_type = self._motor_types[motor_name] + data = self._encode_mit_packet(motor_type, kp, kd, position_degrees, velocity_deg_per_sec, torque) + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name + + responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT) + for recv_id, motor_name in recv_id_to_motor.items(): + if msg := responses.get(recv_id): + self._process_response(motor_name, msg) + + def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int: + """Convert float to unsigned integer for CAN transmission.""" + x = max(x_min, min(x_max, x)) # Clamp to range + span = x_max - x_min + data_norm = (x - x_min) / span + return int(data_norm * ((1 << bits) - 1)) + + def _uint_to_float(self, x: int, x_min: float, x_max: float, bits: int) -> float: + """Convert unsigned integer from CAN to float.""" + span = x_max - x_min + data_norm = float(x) / ((1 << bits) - 1) + return data_norm * span + x_min + + def _decode_motor_state(self, data: bytearray | bytes) -> tuple[float, float, float, float]: + """ + Decode motor state from CAN data. 
+ + Returns: + Tuple of (position_degrees, velocity_deg_per_sec, torque, temp_mos) + """ + if len(data) < 8: + raise ValueError("Invalid motor state data") + + # Extract encoded values + motor_id = data[0] + motor_name = self._id_to_name[motor_id] + q_uint = (data[1] << 8) | data[2] + dq_uint = (data[3] << 4) | (data[4] >> 4) + tau_uint = ((data[4] & 0x0F) << 8) | data[5] + t_mos = (data[6] << 8) | data[7] + + motor_type = self._motor_types[motor_name] + # Get motor limits + pmax, vmax, tmax = MOTOR_LIMIT_PARAMS[motor_type] + + # Decode to physical values (radians) + position_rad = self._uint_to_float(q_uint, -pmax, pmax, 16) + velocity_rad_per_sec = self._uint_to_float(dq_uint, -vmax, vmax, 12) + torque = self._uint_to_float(tau_uint, -tmax, tmax, 12) + + # Convert to degrees + position_degrees = np.degrees(position_rad) + velocity_deg_per_sec = np.degrees(velocity_rad_per_sec) + + # Update cached state + self.last_feedback_time[motor_name] = time.time() + self._last_known_states[motor_name] = { + "position": position_degrees, + "velocity": velocity_deg_per_sec, + "torque": torque, + "temp_mos": t_mos / 10, + # Not available in Robstride MIT feedback. 
+ "temp_rotor": 0.0, + } + return position_degrees, velocity_deg_per_sec, torque, t_mos / 10 + + def _process_response(self, motor: str, msg: can.Message) -> None: + """Decode a feedback frame and update the cache for one motor.""" + try: + self._decode_motor_state(msg.data) + except Exception as e: + logger.warning(f"Failed to decode response from {motor}: {e}") + + def _get_cached_value(self, motor: str, data_name: str) -> Value: + """Retrieve a specific value from the state cache.""" + state = self._last_known_states[motor] + mapping: dict[str, Any] = { + "Present_Position": state["position"], + "Present_Velocity": state["velocity"], + "Present_Torque": state["torque"], + "Temperature_MOS": state["temp_mos"], + } + if data_name == "Temperature_Rotor": + raise NotImplementedError("Rotor temperature reading not accessible.") + if data_name not in mapping: + raise ValueError(f"Unknown data_name: {data_name}") + return mapping[data_name] + + @check_if_not_connected + def read( + self, + data_name: str, + motor: str, + ) -> Value: + """Read a value from a single motor. Positions are always in degrees.""" + + # Refresh motor to get latest state + t_init = time.time() + if ( + self.last_feedback_time[motor] is None + or t_init - (self.last_feedback_time[motor] or 0) > STATE_CACHE_TTL_S + ): + self.update_motor_state(motor) + + return self._get_cached_value(motor, data_name) + + @check_if_not_connected + def write( + self, + data_name: str, + motor: str, + value: Value, + ) -> None: + """Write a value to a single motor. 
Positions are always in degrees.""" + motor_name = self._get_motor_name(motor) + + if data_name in ("Kp", "Kd"): + self._gains[motor_name][data_name.lower()] = float(value) + elif data_name == "Goal_Position": + # Use MIT control with position in degrees + kp = self._gains[motor_name]["kp"] + kd = self._gains[motor_name]["kd"] + self._mit_control(motor, kp, kd, value, 0, 0) + elif data_name == "Goal_Velocity": + # Use Velocity control mode + if self.operation_mode[motor_name] != ControlMode.VEL: + raise RuntimeError(f"Motor '{motor_name}' is not in velocity control mode.") + current_limit_a = 5.0 # Example current limit / not specified in doc. This mode is rarely used and primarily intended for diagnostics + self._speed_control(motor, value, current_limit_a) + else: + raise ValueError(f"Writing {data_name} not supported in MIT mode") + + def sync_read( + self, + data_name: str, + motors: str | list[str] | None = None, + ) -> dict[str, Value]: + """ + Read the same value from multiple motors simultaneously. + Uses batched operations: sends all refresh commands, then collects all responses. + This is MUCH faster than sequential reads (OpenArms pattern). + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + return {motor: self._get_cached_value(motor, data_name) for motor in target_motors} + + @check_if_not_connected + def sync_write( + self, + data_name: str, + values: dict[str, Value], + ) -> None: + """ + Write different values to multiple motors simultaneously. Positions are always in degrees. + Uses batched operations for Goal_Position (MIT mode): sends all commands first, then collects the responses. Other data names fall back to individual writes, one motor at a time.
+ """ + if data_name in ("Kp", "Kd"): + key = data_name.lower() + for motor, val in values.items(): + motor_name = self._get_motor_name(motor) + self._gains[motor_name][key] = float(val) + elif data_name == "Goal_Position": + commands: dict[NameOrID, tuple[float, float, float, float, float]] = {} + for motor, value_degrees in values.items(): + motor_name = self._get_motor_name(motor) + commands[motor] = ( + self._gains[motor_name]["kp"], + self._gains[motor_name]["kd"], + float(value_degrees), + 0.0, + 0.0, + ) + self._mit_control_batch(commands) + else: + # Fall back to individual writes for other data types + for motor, value in values.items(): + self.write(data_name, motor, value) + + def sync_read_all_states( + self, + motors: str | list[str] | None = None, + *, + num_retry: int = 0, + ) -> dict[str, MotorState]: + """ + Read ALL motor states (position, velocity, torque) with Robstride TTL refresh policy. + """ + target_motors = self._get_motors_list(motors) + self._batch_refresh(target_motors) + return {motor: self._last_known_states[motor].copy() for motor in target_motors} + + def _batch_refresh(self, motors: list[str]) -> None: + """Refresh a set of motors and update the feedback cache.""" + init_time = time.time() + updated_motors: list[str] = [] + + for motor in motors: + if ( + self.last_feedback_time[motor] is not None + and (init_time - (self.last_feedback_time[motor] or 0)) < STATE_CACHE_TTL_S + ): + continue + motor_id = self._get_motor_id(motor) + data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT] + msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False) + self._bus().send(msg) + updated_motors.append(motor) + + expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors] + responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT) + + for response in responses.values(): + payload_motor_name = self._recv_id_to_motor.get(response.data[0]) + if payload_motor_name is not None: + 
self._process_response(payload_motor_name, response) + else: + # Fallback: still attempt to decode based on payload byte0 mapping. + self._decode_motor_state(response.data) + + for motor in updated_motors: + recv_id = self._get_motor_recv_id(motor) + if recv_id not in responses: + logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.") + + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration data from motors.""" + # Robstride motors don't store calibration internally + # Return existing calibration or empty dict + return self.calibration if self.calibration else {} + + def write_calibration(self, calibration_dict: dict[str, MotorCalibration], cache: bool = True) -> None: + """Write calibration data to motors.""" + # Robstride motors don't store calibration internally + # Just cache it in memory + if cache: + self.calibration = calibration_dict + + def record_ranges_of_motion( + self, motors: str | list[str] | None = None, display_values: bool = True + ) -> tuple[dict[str, Value], dict[str, Value]]: + """ + Interactively record the min/max values of each motor in degrees. + + Move the joints by hand (with torque disabled) while the method streams live positions. + Press Enter to finish. + """ + target_motors = self._get_motors_list(motors) + + # Disable torque for manual movement + self.disable_torque(target_motors) + time.sleep(0.1) + + # Get initial positions (already in degrees) + start_positions = self.sync_read("Present_Position", target_motors) + mins = start_positions.copy() + maxes = start_positions.copy() + + print("\nMove joints through their full range of motion. 
Press ENTER when done.") + user_pressed_enter = False + + while not user_pressed_enter: + positions = self.sync_read("Present_Position", target_motors) + + for motor in target_motors: + if motor in positions: + mins[motor] = int( + min( + positions[motor], + mins.get(motor, positions[motor]), + ) + ) + maxes[motor] = int( + max( + positions[motor], + maxes.get(motor, positions[motor]), + ) + ) + + if display_values: + print("\n" + "=" * 50) + print(f"{'MOTOR':<20} | {'MIN (deg)':>12} | {'POS (deg)':>12} | {'MAX (deg)':>12}") + print("-" * 50) + for motor in target_motors: + if motor in positions: + print( + f"{motor:<20} | {mins[motor]:>12.1f} | {positions[motor]:>12.1f} | {maxes[motor]:>12.1f}" + ) + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + # Move cursor up to overwrite the previous output + move_cursor_up(len(target_motors) + 4) + + time.sleep(0.05) + + # Re-enable torque + self.enable_torque(target_motors) + + # Validate ranges + for motor in target_motors: + if (motor in mins) and (motor in maxes) and (abs(maxes[motor] - mins[motor]) < 5.0): + raise ValueError(f"Motor {motor} has insufficient range of motion (< 5 degrees)") + + return mins, maxes + + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + """Convert motor specification to list of motor names.""" + if motors is None: + return list(self.motors.keys()) + elif isinstance(motors, str): + return [motors] + elif isinstance(motors, list): + return motors + else: + raise TypeError(f"Invalid motors type: {type(motors)}") + + def _get_motor_id(self, motor: NameOrID) -> int: + """Get CAN ID for a motor.""" + if isinstance(motor, str): + if motor in self.motors: + return self.motors[motor].id + else: + raise ValueError(f"Unknown motor: {motor}") + else: + return motor + + def _get_motor_name(self, motor: NameOrID) -> str: + """Get motor name from name or ID.""" + if isinstance(motor, str): + return motor + else: + for name, m in 
self.motors.items(): + if m.id == motor: + return name + raise ValueError(f"Unknown motor ID: {motor}") + + def _get_motor_recv_id(self, motor: NameOrID) -> int: + """Return the expected ID found in feedback payload byte0 for this motor. + + Robstride MIT feedback frames encode an ID in data[0]. Some setups expose it as + `motor.recv_id`; otherwise we fall back to the configured `motor.id`. + """ + motor_name = self._get_motor_name(motor) + motor_obj = self.motors[motor_name] + + recv_id = getattr(motor_obj, "recv_id", None) + if recv_id is None: + logger.debug( + "Motor '%s' has no recv_id; falling back to motor.id=%s for feedback demux.", + motor_name, + motor_obj.id, + ) + return motor_obj.id + + return recv_id + + @cached_property + def is_calibrated(self) -> bool: + """Check if motors are calibrated.""" + return bool(self.calibration) diff --git a/src/lerobot/motors/robstride/tables.py b/src/lerobot/motors/robstride/tables.py new file mode 100644 index 000000000..2fc1a97b0 --- /dev/null +++ b/src/lerobot/motors/robstride/tables.py @@ -0,0 +1,120 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Configuration tables for Robstride motors.""" + +from enum import IntEnum + + +# Motor type definitions +class MotorType(IntEnum): + O0 = 0 + O1 = 1 + O2 = 2 + O3 = 3 + O4 = 4 + O5 = 5 + ELO5 = 6 + O6 = 7 + + +class CommMode(IntEnum): + PrivateProtocole = 0 + CANopen = 1 + MIT = 2 + + +# Control modes +class ControlMode(IntEnum): + MIT = 0 + POS_VEL = 1 + VEL = 2 + + +# Motor limit parameters [PMAX, VMAX, TMAX] +# PMAX: Maximum position (rad) +# VMAX: Maximum velocity (rad/s) +# TMAX: Maximum torque (N·m) +MOTOR_LIMIT_PARAMS: dict[MotorType, tuple[float, float, float]] = { + MotorType.O0: (12.57, 33, 14), + MotorType.O1: (12.57, 44, 17), + MotorType.O2: (12.57, 33, 20), + MotorType.O3: (12.57, 33, 60), + MotorType.O4: (12.57, 33, 120), + MotorType.O5: (12.57, 50, 5.5), + MotorType.ELO5: (12.57, 50, 6), + MotorType.O6: (112.5, 50, 36), +} + +# Motor model names +MODEL_NAMES = { + MotorType.O0: "O0", + MotorType.O1: "O1", + MotorType.O2: "O2", + MotorType.O3: "O3", + MotorType.O4: "O4", + MotorType.O5: "O5", + MotorType.ELO5: "ELO5", + MotorType.O6: "O6", +} + +# Motor resolution table (encoder counts per revolution) +MODEL_RESOLUTION = { + "O0": 65536, + "O1": 65536, + "O2": 65536, + "O3": 65536, + "O4": 65536, + "O5": 65536, + "ELO5": 65536, + "O6": 65536, +} + +# CAN baudrates supported by Robstride motors +AVAILABLE_BAUDRATES = [ + 1000000, # 4: 1 mbps (default) +] +DEFAULT_BAUDRATE = 1000000 + +# Default timeout in milliseconds +DEFAULT_TIMEOUT_MS = 0 # disabled by default, otherwise 20000 is 1s + + +# Data that should be normalized +NORMALIZED_DATA = ["Present_Position", "Goal_Position"] + + +# MIT control parameter ranges +MIT_KP_RANGE = (0.0, 500.0) +MIT_KD_RANGE = (0.0, 5.0) + +# CAN frame command IDs +CAN_CMD_ENABLE = 0xFC +CAN_CMD_DISABLE = 0xFD +CAN_CMD_SET_ZERO = 0xFE +CAN_CMD_CLEAR_FAULT = 0xFB + + +CAN_CMD_QUERY_PARAM = 0x33 +CAN_CMD_WRITE_PARAM = 0x55 +CAN_CMD_SAVE_PARAM = 0xAA + +# CAN ID for parameter operations +CAN_PARAM_ID = 0x7FF + +
+RUNNING_TIMEOUT = 0.001 +PARAM_TIMEOUT = 0.01 + +STATE_CACHE_TTL_S = 0.02 diff --git a/src/lerobot/scripts/lerobot_setup_can.py b/src/lerobot/scripts/lerobot_setup_can.py index a31727ea4..b28fca44d 100644 --- a/src/lerobot/scripts/lerobot_setup_can.py +++ b/src/lerobot/scripts/lerobot_setup_can.py @@ -152,6 +152,7 @@ def test_motor(bus, motor_id: int, timeout: float, use_fd: bool): ) try: bus.send(disable_msg) + bus.recv(timeout=0.1) # Clear any pending responses except Exception: print(f"Error sending message to motor 0x{motor_id:02X}") From fcabfd32a5a45f0a8210179a6cc8f605815f0dab Mon Sep 17 00:00:00 2001 From: Yuta Nakagawa Date: Tue, 24 Feb 2026 01:11:46 +0900 Subject: [PATCH 108/111] chore(docs): update the document for Phone teleop to clarify how to use the examples (#2991) * update the document for Phone teleope to clarify how to use the examples * Update docs/source/phone_teleop.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Yuta Nakagawa --------- Signed-off-by: Yuta Nakagawa Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Steven Palma --- docs/source/phone_teleop.mdx | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/source/phone_teleop.mdx b/docs/source/phone_teleop.mdx index 06e524975..678783e7b 100644 --- a/docs/source/phone_teleop.mdx +++ b/docs/source/phone_teleop.mdx @@ -66,12 +66,13 @@ Run on of the examples scripts to teleoperate, record a dataset, replay a datase All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port. -Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf) +Additionally you need to **copy the URDF of the robot into the examples folder**. 
For the examples in this tutorial (using SO100/SO101), copy the `SO101` folder from the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101) into the `examples/phone_to_so100/` directory, so that the URDF file path becomes `examples/phone_to_so100/SO101/so101_new_calib.urdf`. - Run this example to teleoperate: ```bash - python examples/phone_to_so100/teleoperate.py + cd examples/phone_to_so100 + python teleoperate.py ``` After running the example: @@ -84,19 +85,22 @@ Additionally you can customize mapping or safety limits by editing the processor - Run this example to record a dataset, which saves absolute end effector observations and actions: ```bash - python examples/phone_to_so100/record.py + cd examples/phone_to_so100 + python record.py ``` - Run this example to replay recorded episodes: ```bash - python examples/phone_to_so100/replay.py + cd examples/phone_to_so100 + python replay.py ``` - Run this example to evaluate a pretrained policy: ```bash - python examples/phone_to_so100/evaluate.py + cd examples/phone_to_so100 + python evaluate.py ``` ### Important pipeline steps and options From 7dbbaa3727e8866ce6649e895d152f129a0c6d32 Mon Sep 17 00:00:00 2001 From: Guilherme Miotto Date: Mon, 23 Feb 2026 17:11:55 +0100 Subject: [PATCH 109/111] Small comment fix (#2990) Co-authored-by: Steven Palma --- src/lerobot/policies/smolvla/configuration_smolvla.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lerobot/policies/smolvla/configuration_smolvla.py b/src/lerobot/policies/smolvla/configuration_smolvla.py index c32c8a60e..c696265f2 100644 --- a/src/lerobot/policies/smolvla/configuration_smolvla.py +++ b/src/lerobot/policies/smolvla/configuration_smolvla.py @@ -85,7 +85,7 @@ class SmolVLAConfig(PreTrainedConfig): scheduler_decay_lr: float = 2.5e-6 vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone. 
- load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights + load_vlm_weights: bool = False # Set to False in case of training the expert from scratch. True when init from pretrained SmolVLA weights add_image_special_tokens: bool = False # Whether to use special image tokens around image features. From 0f44adbeecffca72e9b0f41c58b8222402ff4613 Mon Sep 17 00:00:00 2001 From: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> Date: Tue, 24 Feb 2026 00:51:13 +0800 Subject: [PATCH 110/111] docs: fix HF_USER export command to correctly parse username (#2932) * Fix HF_USER extraction command in documentation Updated command to extract the username from hf auth output. Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> * Correct HF_USER variable assignment in documentation Fix the variable extraction from hf auth output. Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> * Update docs/source/il_robots.mdx Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> --------- Signed-off-by: Yuan Haokuan <138340416+WilbertYuan@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Steven Palma --- docs/source/il_robots.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index 7fc770b0c..bad88f88e 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential Then store your Hugging Face repository name in a variable: ```bash -HF_USER=$(hf auth whoami | head -n 1) +HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}') echo $HF_USER ``` From 7fd71c83a3c5ba496b6a19aad96a048b7c42410f Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Dominik=20Pa=C4=BEo?= Date: Mon, 23 Feb 2026 20:41:20 +0100 Subject: [PATCH 111/111] docs: add WSL evdev installation note (#2855) Add a note in the installation guide explaining that users on WSL need to install evdev to avoid build issues. See: https://github.com/huggingface/lerobot/issues/2528 Signed-off-by: Steven Palma Co-authored-by: Steven Palma --- docs/source/installation.mdx | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 8cc83843e..a112377c1 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -40,6 +40,13 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. +> [!NOTE] +> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to install `evdev` with the following command: +> +> ```bash +> conda install evdev -c conda-forge +> ``` + ## Step 3: Install LeRobot 🤗 ### From Source