diff --git a/.cache/calibration/aloha_default/left_follower.json b/.cache/calibration/aloha_default/left_follower.json deleted file mode 100644 index 336c238a0..000000000 --- a/.cache/calibration/aloha_default/left_follower.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "homing_offset": [ - 2048, - 3072, - 3072, - -1024, - -1024, - 2048, - -2048, - 2048, - -2048 - ], - "drive_mode": [ - 1, - 1, - 1, - 0, - 0, - 1, - 0, - 1, - 0 - ], - "start_pos": [ - 2015, - 3058, - 3061, - 1071, - 1071, - 2035, - 2152, - 2029, - 2499 - ], - "end_pos": [ - -1008, - -1963, - -1966, - 2141, - 2143, - -971, - 3043, - -1077, - 3144 - ], - "calib_mode": [ - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "LINEAR" - ], - "motor_names": [ - "waist", - "shoulder", - "shoulder_shadow", - "elbow", - "elbow_shadow", - "forearm_roll", - "wrist_angle", - "wrist_rotate", - "gripper" - ] -} diff --git a/.cache/calibration/aloha_default/left_leader.json b/.cache/calibration/aloha_default/left_leader.json deleted file mode 100644 index d933f2bab..000000000 --- a/.cache/calibration/aloha_default/left_leader.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "homing_offset": [ - 2048, - 3072, - 3072, - -1024, - -1024, - 2048, - -2048, - 2048, - -1024 - ], - "drive_mode": [ - 1, - 1, - 1, - 0, - 0, - 1, - 0, - 1, - 0 - ], - "start_pos": [ - 2035, - 3024, - 3019, - 979, - 981, - 1982, - 2166, - 2124, - 1968 - ], - "end_pos": [ - -990, - -2017, - -2015, - 2078, - 2076, - -1030, - 3117, - -1016, - 2556 - ], - "calib_mode": [ - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "LINEAR" - ], - "motor_names": [ - "waist", - "shoulder", - "shoulder_shadow", - "elbow", - "elbow_shadow", - "forearm_roll", - "wrist_angle", - "wrist_rotate", - "gripper" - ] -} diff --git a/.cache/calibration/aloha_default/right_follower.json b/.cache/calibration/aloha_default/right_follower.json deleted file mode 100644 index bc69dfafd..000000000 --- a/.cache/calibration/aloha_default/right_follower.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "homing_offset": [ - 2048, - 3072, - 3072, - -1024, - -1024, - 2048, - -2048, - 2048, - -2048 - ], - "drive_mode": [ - 1, - 1, - 1, - 0, - 0, - 1, - 0, - 1, - 0 - ], - "start_pos": [ - 2056, - 2895, - 2896, - 1191, - 1190, - 2018, - 2051, - 2056, - 2509 - ], - "end_pos": [ - -1040, - -2004, - -2006, - 2126, - 2127, - -1010, - 3050, - -1117, - 3143 - ], - "calib_mode": [ - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "LINEAR" - ], - "motor_names": [ - "waist", - "shoulder", - "shoulder_shadow", - "elbow", - "elbow_shadow", - "forearm_roll", - "wrist_angle", - "wrist_rotate", - "gripper" - ] -} diff --git a/.cache/calibration/aloha_default/right_leader.json b/.cache/calibration/aloha_default/right_leader.json deleted file mode 100644 index d96d1de9b..000000000 --- a/.cache/calibration/aloha_default/right_leader.json +++ /dev/null @@ -1,68 +0,0 @@ -{ - "homing_offset": [ - 2048, - 3072, - 3072, - -1024, - -1024, - 2048, - -2048, - 2048, - -2048 - ], - "drive_mode": [ - 1, - 1, - 1, - 0, - 0, - 1, - 0, - 1, - 0 - ], - "start_pos": [ - 2068, - 3034, - 3030, - 1038, - 1041, - 1991, - 1948, - 2090, - 1985 - ], - "end_pos": [ - -1025, - -2014, - -2015, - 2058, - 2060, - -955, - 3091, - -940, - 2576 - ], - "calib_mode": [ - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "DEGREE", - "LINEAR" - ], - "motor_names": [ - "waist", - "shoulder", - "shoulder_shadow", - "elbow", - "elbow_shadow", - "forearm_roll", - "wrist_angle", - "wrist_rotate", - "gripper" - ] -} diff --git a/.github/workflows/build-docker-images.yml b/.github/workflows/build-docker-images.yml index 0cb11d576..20974b85a 100644 --- a/.github/workflows/build-docker-images.yml +++ b/.github/workflows/build-docker-images.yml @@ -40,24 +40,24 @@ jobs: git lfs install - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 with: cache-binary: false - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: lfs: true persist-credentials: false - name: Login to DockerHub - uses: docker/login-action@v3 + uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_PASSWORD }} - name: Build and Push CPU - uses: docker/build-push-action@v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 with: context: . file: ./docker/lerobot-cpu/Dockerfile @@ -78,24 +78,24 @@ jobs: git lfs install - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 with: cache-binary: false - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: lfs: true persist-credentials: false - name: Login to DockerHub - uses: docker/login-action@v3 + uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_PASSWORD }} - name: Build and Push GPU - uses: docker/build-push-action@v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 with: context: . file: ./docker/lerobot-gpu/Dockerfile @@ -110,23 +110,23 @@ jobs: group: aws-general-8-plus steps: - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 with: cache-binary: false - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false - name: Login to DockerHub - uses: docker/login-action@v3 + uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_PASSWORD }} - name: Build and Push GPU dev - uses: docker/build-push-action@v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 with: context: . file: ./docker/lerobot-gpu-dev/Dockerfile diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml new file mode 100644 index 000000000..884e2e4b5 --- /dev/null +++ b/.github/workflows/build_documentation.yml @@ -0,0 +1,23 @@ +name: Build documentation + +on: + workflow_dispatch: + push: + paths: + - "docs/**" + branches: + - main + - doc-builder* + - v*-release + + +jobs: + build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers + uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main + with: + commit_sha: ${{ github.sha }} + package: lerobot + additional_args: --not_python_module + secrets: + token: ${{ secrets.HUGGINGFACE_PUSH }} + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml new file mode 100644 index 000000000..51bab10d5 --- /dev/null +++ b/.github/workflows/build_pr_documentation.yml @@ -0,0 +1,19 @@ +name: Build PR Documentation + +on: + pull_request: + paths: + - "docs/**" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers + uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main + with: + commit_sha: ${{ github.event.pull_request.head.sha }} + pr_number: ${{ github.event.number }} + package: lerobot + additional_args: --not_python_module diff --git a/.github/workflows/nightly-tests.yml b/.github/workflows/nightly-tests.yml index adac9f20d..be248b335 100644 --- a/.github/workflows/nightly-tests.yml +++ b/.github/workflows/nightly-tests.yml @@ -33,7 +33,7 @@ jobs: runs-on: group: aws-general-8-plus container: - image: huggingface/lerobot-cpu:latest + image: huggingface/lerobot-cpu:latest # zizmor: ignore[unpinned-images] options: --shm-size "16gb" credentials: username: ${{ secrets.DOCKERHUB_USERNAME }} @@ -60,7 +60,7 @@ jobs: CUDA_VISIBLE_DEVICES: "0" TEST_TYPE: "single_gpu" container: - image: huggingface/lerobot-gpu:latest + image: huggingface/lerobot-gpu:latest # zizmor: ignore[unpinned-images] options: --gpus all --shm-size "16gb" credentials: username: ${{ secrets.DOCKERHUB_USERNAME }} diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 332b543c2..1c048c4fe 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -33,12 +33,12 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false - name: Set up Python - uses: actions/setup-python@v4 + uses: actions/setup-python@7f4fc3e22c37d6ff65e88745f38bd3157c663f7c # v4.9.1 with: python-version: ${{ env.PYTHON_VERSION }} @@ -64,9 +64,9 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout Repository - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false - name: typos-action - uses: crate-ci/typos@v1.29.10 + uses: crate-ci/typos@db35ee91e80fbb447f33b0e5fbddb24d2a1a884f # v1.29.10 diff --git a/.github/workflows/test-docker-build.yml b/.github/workflows/test-docker-build.yml index c31025645..7a1e93274 100644 --- a/.github/workflows/test-docker-build.yml +++ b/.github/workflows/test-docker-build.yml @@ -35,7 +35,7 @@ jobs: matrix: ${{ steps.set-matrix.outputs.matrix }} steps: - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false @@ -64,17 +64,17 @@ jobs: docker-file: ${{ fromJson(needs.get_changed_files.outputs.matrix) }} steps: - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 with: cache-binary: false - name: Check out code - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: persist-credentials: false - name: Build Docker image - uses: docker/build-push-action@v5 + uses: docker/build-push-action@ca052bb54ab0790a636c9b5f226502c73d547a25 # v5.4.0 with: file: ${{ matrix.docker-file }} context: . diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d91c53646..8822956cf 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -50,7 +50,7 @@ jobs: env: MUJOCO_GL: egl steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: lfs: true # Ensure LFS files are pulled persist-credentials: false @@ -62,7 +62,7 @@ jobs: sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev - name: Install uv and python - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 with: enable-cache: true version: ${{ env.UV_VERSION }} @@ -85,7 +85,7 @@ jobs: env: MUJOCO_GL: egl steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: lfs: true # Ensure LFS files are pulled persist-credentials: false @@ -94,7 +94,7 @@ jobs: run: sudo apt-get update && sudo apt-get install -y ffmpeg - name: Install uv and python - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 with: enable-cache: true version: ${{ env.UV_VERSION }} @@ -117,7 +117,7 @@ jobs: env: MUJOCO_GL: egl steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: lfs: true # Ensure LFS files are pulled persist-credentials: false @@ -129,7 +129,7 @@ jobs: sudo apt-get install -y libegl1-mesa-dev ffmpeg portaudio19-dev - name: Install uv and python - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@d4b2f3b6ecc6e67c4457f6d3e41ec42d3d0fcb86 # v5.4.2 with: enable-cache: true version: ${{ env.UV_VERSION }} diff --git a/.github/workflows/trufflehog.yml b/.github/workflows/trufflehog.yml index 166e05908..704a3baaa 100644 --- a/.github/workflows/trufflehog.yml +++ b/.github/workflows/trufflehog.yml @@ -24,12 +24,12 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: fetch-depth: 0 persist-credentials: false - name: Secret Scanning - uses: trufflesecurity/trufflehog@main + uses: trufflesecurity/trufflehog@90694bf9af66e7536abc5824e7a87246dbf933cb # v3.88.35 with: extra_args: --only-verified diff --git a/.github/workflows/upload_pr_documentation.yml b/.github/workflows/upload_pr_documentation.yml new file mode 100644 index 000000000..32665930b --- /dev/null +++ b/.github/workflows/upload_pr_documentation.yml @@ -0,0 +1,16 @@ +name: Upload PR Documentation + +on: # zizmor: ignore[dangerous-triggers] We follow the same pattern as in Transformers + workflow_run: + workflows: [ "Build PR Documentation" ] + types: + - completed + +jobs: + build: # zizmor: ignore[excessive-permissions] We follow the same pattern as in Transformers + uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main + with: + package_name: lerobot + secrets: + hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} + comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a9e4a8565..23a180046 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -269,9 +269,6 @@ Follow these steps to start contributing: the PR as a draft PR. These are useful to avoid duplicated work, and to differentiate it from PRs ready to be merged; 4. Make sure existing tests pass; - ### Tests diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..d2608ca0a --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +include lerobot/templates/lerobot_modelcard_template.md +include lerobot/common/datasets/card_template.md diff --git a/README.md b/README.md index 679fa7bfe..42e0ee2c5 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ pip install -e . ``` > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: -`sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) +`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) For simulations, 🤗 LeRobot comes with gymnasium environments that can be installed as extras: - [aloha](https://github.com/huggingface/gym-aloha) diff --git a/benchmarks/video/capture_camera_feed.py b/benchmarks/video/capture_camera_feed.py old mode 100644 new mode 100755 diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index a73b77b13..fed394d15 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -418,7 +418,7 @@ if __name__ == "__main__": "--vcodec", type=str, nargs="*", - default=["libx264", "libx265", "libsvtav1"], + default=["libx264", "hevc", "libsvtav1"], help="Video codecs to be tested", ) parser.add_argument( @@ -448,7 +448,7 @@ if __name__ == "__main__": # nargs="*", # default=[0, 1], # help="Use the fastdecode tuning option. 0 disables it. " - # "For libx264 and libx265, only 1 is possible. " + # "For libx264 and libx265/hevc, only 1 is possible. " # "For libsvtav1, 1, 2 or 3 are possible values with a higher number meaning a faster decoding optimization", # ) parser.add_argument( diff --git a/docker/lerobot-cpu/Dockerfile b/docker/lerobot-cpu/Dockerfile index 13a45d249..85c31ac1a 100644 --- a/docker/lerobot-cpu/Dockerfile +++ b/docker/lerobot-cpu/Dockerfile @@ -22,7 +22,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ COPY . /lerobot WORKDIR /lerobot RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \ - && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" \ + && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, smolvla]" \ --extra-index-url https://download.pytorch.org/whl/cpu # Execute in bash shell rather than python diff --git a/docker/lerobot-gpu-dev/Dockerfile b/docker/lerobot-gpu-dev/Dockerfile index 561a7cff6..4d25b2550 100644 --- a/docker/lerobot-gpu-dev/Dockerfile +++ b/docker/lerobot-gpu-dev/Dockerfile @@ -14,7 +14,7 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ tcpdump sysstat screen tmux \ libglib2.0-0 libgl1-mesa-glx libegl1-mesa \ speech-dispatcher portaudio19-dev libgeos-dev \ - python${PYTHON_VERSION} python${PYTHON_VERSION}-venv \ + python${PYTHON_VERSION} python${PYTHON_VERSION}-venv python${PYTHON_VERSION}-dev \ && apt-get clean && rm -rf /var/lib/apt/lists/* # Install ffmpeg build dependencies. See: diff --git a/docker/lerobot-gpu/Dockerfile b/docker/lerobot-gpu/Dockerfile index 642a8ded6..746ea29b7 100644 --- a/docker/lerobot-gpu/Dockerfile +++ b/docker/lerobot-gpu/Dockerfile @@ -21,4 +21,4 @@ RUN apt-get update && apt-get install -y --no-install-recommends \ COPY . /lerobot WORKDIR /lerobot RUN /opt/venv/bin/pip install --upgrade --no-cache-dir pip \ - && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel]" + && /opt/venv/bin/pip install --no-cache-dir ".[test, aloha, xarm, pusht, dynamixel, smolvla]" diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..275fee46b --- /dev/null +++ b/docs/README.md @@ -0,0 +1,137 @@ + + +# Generating the documentation + +To generate the documentation, you first have to build it. Several packages are necessary to build the doc, +you can install them with the following command, at the root of the code repository: + +```bash +pip install -e ".[docs]" +``` + +You will also need `nodejs`. Please refer to their [installation page](https://nodejs.org/en/download) + +--- +**NOTE** + +You only need to generate the documentation to inspect it locally (if you're planning changes and want to +check how they look before committing for instance). You don't have to `git commit` the built documentation. + +--- + +## Building the documentation + +Once you have setup the `doc-builder` and additional packages, you can generate the documentation by +typing the following command: + +```bash +doc-builder build lerobot docs/source/ --build_dir ~/tmp/test-build +``` + +You can adapt the `--build_dir` to set any temporary folder that you prefer. This command will create it and generate +the MDX files that will be rendered as the documentation on the main website. You can inspect them in your favorite +Markdown editor. + +## Previewing the documentation + +To preview the docs, first install the `watchdog` module with: + +```bash +pip install watchdog +``` + +Then run the following command: + +```bash +doc-builder preview lerobot docs/source/ +``` + +The docs will be viewable at [http://localhost:3000](http://localhost:3000). You can also preview the docs once you have opened a PR. You will see a bot add a comment to a link where the documentation with your changes lives. + +--- +**NOTE** + +The `preview` command only works with existing doc files. When you add a completely new file, you need to update `_toctree.yml` & restart `preview` command (`ctrl-c` to stop it & call `doc-builder preview ...` again). + +--- + +## Adding a new element to the navigation bar + +Accepted files are Markdown (.md). + +Create a file with its extension and put it in the source directory. You can then link it to the toc-tree by putting +the filename without the extension in the [`_toctree.yml`](https://github.com/huggingface/lerobot/blob/main/docs/source/_toctree.yml) file. + +## Renaming section headers and moving sections + +It helps to keep the old links working when renaming the section header and/or moving sections from one document to another. This is because the old links are likely to be used in Issues, Forums, and Social media and it'd make for a much more superior user experience if users reading those months later could still easily navigate to the originally intended information. + +Therefore, we simply keep a little map of moved sections at the end of the document where the original section was. The key is to preserve the original anchor. + +So if you renamed a section from: "Section A" to "Section B", then you can add at the end of the file: + +``` +Sections that were moved: + +[ Section A ] +``` +and of course, if you moved it to another file, then: + +``` +Sections that were moved: + +[ Section A ] +``` + +Use the relative style to link to the new file so that the versioned docs continue to work. + +For an example of a rich moved sections set please see the very end of [the transformers Trainer doc](https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/trainer.md). + +### Adding a new tutorial + +Adding a new tutorial or section is done in two steps: + +- Add a new file under `./source`. This file can either be ReStructuredText (.rst) or Markdown (.md). +- Link that file in `./source/_toctree.yml` on the correct toc-tree. + +Make sure to put your new file under the proper section. If you have a doubt, feel free to ask in a Github Issue or PR. + +### Writing source documentation + +Values that should be put in `code` should either be surrounded by backticks: \`like so\`. Note that argument names +and objects like True, None or any strings should usually be put in `code`. + +#### Writing a multi-line code block + +Multi-line code blocks can be useful for displaying examples. They are done between two lines of three backticks as usual in Markdown: + + +```` +``` +# first line of code +# second line +# etc +``` +```` + +#### Adding an image + +Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like +the ones hosted on [`hf-internal-testing`](https://huggingface.co/hf-internal-testing) in which to place these files and reference +them by URL. We recommend putting them in the following dataset: [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images). +If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images +to this dataset. diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml new file mode 100644 index 000000000..ea80e8257 --- /dev/null +++ b/docs/source/_toctree.yml @@ -0,0 +1,44 @@ +- sections: + - local: index + title: LeRobot + - local: installation + title: Installation + title: Get started +- sections: + - local: il_robots + title: Imitation Learning for Robots + - local: il_sim + title: Imitation Learning in Sim + - local: cameras + title: Cameras + - local: integrate_hardware + title: Bring Your Own Hardware + - local: hilserl + title: Train a Robot with RL + - local: hilserl_sim + title: Train RL in Simulation + title: "Tutorials" +- sections: + - local: smolvla + title: Finetune SmolVLA + title: "Policies" +- sections: + - local: so101 + title: SO-101 + - local: so100 + title: SO-100 + - local: koch + title: Koch v1.1 + - local: lekiwi + title: LeKiwi + title: "Robots" +- sections: + - local: notebooks + title: Notebooks + title: "Resources" +- sections: + - local: contributing + title: Contribute to LeRobot + - local: backwardcomp + title: Backward compatibility + title: "About" diff --git a/docs/source/backwardcomp.mdx b/docs/source/backwardcomp.mdx new file mode 100644 index 000000000..555239170 --- /dev/null +++ b/docs/source/backwardcomp.mdx @@ -0,0 +1,82 @@ +# Backward compatibility + +## Hardware API redesign + +PR [#777](https://github.com/huggingface/lerobot/pull/777) improves the LeRobot calibration but is **not backward-compatible**. Below is a overview of what changed and how you can continue to work with datasets created before this pull request. + +### What changed? + +| | Before PR #777 | After PR #777 | +| --------------------------------- | ------------------------------------------------- | --------------------------------------------------------------------------- | +| **Joint range** | Degrees `-180...180°` | **Normalised range** Joints: `–100...100` Gripper: `0...100` | +| **Zero position (SO100 / SO101)** | Arm fully extended horizontally | **In middle of the range for each joint** | +| **Boundary handling** | Software safeguards to detect ±180 ° wrap-arounds | No wrap-around logic needed due to mid-range zero | + +--- + +### Impact on existing datasets + +* Recorded trajectories created **before** PR #777 will replay incorrectly if loaded directly: + * Joint angles are offset and incorrectly normalized. +* Any models directly finetuned or trained on the old data will need their inputs and outputs converted. + +### Using datasets made with the previous calibration system +We provide a migration example script for replaying an episode recorded with the previous calibration here: `examples/backward_compatibility/replay.py`. +Below we take you through the modifications that are done in the example script to make the previous calibration datasets work. + +```diff ++ key = f"{name.removeprefix('main_')}.pos" + action[key] = action_array[i].item() ++ action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) ++ action["elbow_flex.pos"] -= 90 +``` + +Let's break this down. +New codebase uses `.pos` suffix for the position observations and we have removed `main_` prefix: +```python +key = f"{name.removeprefix('main_')}.pos" +``` + +For `"shoulder_lift"` (id = 2), the 0 position is changed by -90 degrees and the direction is reversed compared to old calibration/code. +```python +action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) +``` +For `"elbow_flex"` (id = 3), the 0 position is changed by -90 degrees compared to old calibration/code. +```python +action["elbow_flex.pos"] -= 90 +``` + +To use degrees normalization we then set the `--robot.use_degrees` option to `true`. +```diff +python examples/backward_compatibility/replay.py \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem5A460814411 \ + --robot.id=blue \ ++ --robot.use_degrees=true \ + --dataset.repo_id=my_dataset_id \ + --dataset.episode=0 +``` + +### Using policies trained with the previous calibration system + +Policies output actions in the same format as the datasets (`torch.Tensors`). Therefore, the same transformations should be applied. + +To find these transformations, we recommend to first try and and replay an episode of the dataset your policy was trained on using the section above. +Then, add these same transformations on your inference script (shown here in the `record.py` script): +```diff +action_values = predict_action( + observation_frame, + policy, + get_safe_torch_device(policy.config.device), + policy.config.use_amp, + task=single_task, + robot_type=robot.robot_type, + ) + action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} + ++ action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) ++ action["elbow_flex.pos"] -= 90 + robot.send_action(action) +``` + +If you have questions or run into migration issues, feel free to ask them on [Discord](https://discord.gg/s3KuuzsPFb) diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx new file mode 100644 index 000000000..d8a49c1ee --- /dev/null +++ b/docs/source/cameras.mdx @@ -0,0 +1,173 @@ +# Cameras + +LeRobot offers multiple options for video capture, including phone cameras, built-in laptop cameras, external webcams, and Intel RealSense cameras. To efficiently record frames from most cameras, you can use either the `OpenCVCamera` or `RealSenseCamera` class. For additional compatibility details on the `OpenCVCamera` class, refer to the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). + +### Finding your camera + +To instantiate a camera, you need a camera identifier. This identifier might change if you reboot your computer or re-plug your camera, a behavior mostly dependant on your operating system. + +To find the camera indices of the cameras plugged into your system, run the following script: +```bash +python lerobot/find_cameras.py opencv # or realsense for Intel Realsense cameras +``` + +The output will look something like this if you have two cameras connected: +``` +--- Detected Cameras --- +Camera #0: + Name: OpenCV Camera @ 0 + Type: OpenCV + Id: 0 + Backend api: AVFOUNDATION + Default stream profile: + Format: 16.0 + Width: 1920 + Height: 1080 + Fps: 15.0 +-------------------- +(more cameras ...) +``` + +> [!WARNING] +> When using Intel RealSense cameras in `macOS`, you could get this [error](https://github.com/IntelRealSense/librealsense/issues/12307): `Error finding RealSense cameras: failed to set power state`, this can be solved by running the same command with `sudo` permissions. Note that using RealSense cameras in `macOS` is unstable. + + +## Use Cameras + +Below are two examples, demonstrating how to work with the API. + +- **Asynchronous frame capture** using an OpenCV-based camera +- **Color and depth capture** using an Intel RealSense camera + + + + + +```python +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.common.cameras.opencv.camera_opencv import OpenCVCamera +from lerobot.common.cameras.configs import ColorMode, Cv2Rotation + +# Construct an `OpenCVCameraConfig` with your desired FPS, resolution, color mode, and rotation. +config = OpenCVCameraConfig( + index_or_path=0, + fps=15, + width=1920, + height=1080, + color_mode=ColorMode.RGB, + rotation=Cv2Rotation.NO_ROTATION +) + +# Instantiate and connect an `OpenCVCamera`, performing a warm-up read (default). +camera = OpenCVCamera(config) +camera.connect() + +# Read frames asynchronously in a loop via `async_read(timeout_ms)` +try: + for i in range(10): + frame = camera.async_read(timeout_ms=200) + print(f"Async frame {i} shape:", frame.shape) +finally: + camera.disconnect() +``` + + + + +```python +from lerobot.common.cameras.realsense.configuration_realsense import RealSenseCameraConfig +from lerobot.common.cameras.realsense.camera_realsense import RealSenseCamera +from lerobot.common.cameras.configs import ColorMode, Cv2Rotation + +# Create a `RealSenseCameraConfig` specifying your camera’s serial number and enabling depth. +config = RealSenseCameraConfig( + serial_number_or_name="233522074606", + fps=15, + width=640, + height=480, + color_mode=ColorMode.RGB, + use_depth=True, + rotation=Cv2Rotation.NO_ROTATION +) + +# Instantiate and connect a `RealSenseCamera` with warm-up read (default). +camera = RealSenseCamera(config) +camera.connect() + +# Capture a color frame via `read()` and a depth map via `read_depth()`. +try: + color_frame = camera.read() + depth_map = camera.read_depth() + print("Color frame shape:", color_frame.shape) + print("Depth map shape:", depth_map.shape) +finally: + camera.disconnect() +``` + + + + +## Use your phone + + + +To use your iPhone as a camera on macOS, enable the Continuity Camera feature: +- Ensure your Mac is running macOS 13 or later, and your iPhone is on iOS 16 or later. +- Sign in both devices with the same Apple ID. +- Connect your devices with a USB cable or turn on Wi-Fi and Bluetooth for a wireless connection. + +For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac). + +Your iPhone should be detected automatically when running the camera setup script in the next section. + + + + +If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera + +1. *Install `v4l2loopback-dkms` and `v4l-utils`*. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using: +```python +sudo apt install v4l2loopback-dkms v4l-utils +``` +2. *Install [DroidCam](https://droidcam.app) on your phone*. This app is available for both iOS and Android. +3. *Install [OBS Studio](https://obsproject.com)*. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org): +```python +flatpak install flathub com.obsproject.Studio +``` +4. *Install the DroidCam OBS plugin*. This plugin integrates DroidCam with OBS Studio. Install it with: +```python +flatpak install flathub com.obsproject.Studio.Plugin.DroidCam +``` +5. *Start OBS Studio*. Launch with: +```python +flatpak run com.obsproject.Studio +``` +6. *Add your phone as a source*. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`. +7. *Adjust resolution settings*. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in. +8. *Start virtual camera*. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide). +9. *Verify the virtual camera setup*. Use `v4l2-ctl` to list the devices: +```python +v4l2-ctl --list-devices +``` +You should see an entry like: +``` +VirtualCam (platform:v4l2loopback-000): +/dev/video1 +``` +10. *Check the camera resolution*. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`. +```python +v4l2-ctl -d /dev/video1 --get-fmt-video +``` +You should see an entry like: +``` +>>> Format Video Capture: +>>> Width/Height : 640/480 +>>> Pixel Format : 'YUYV' (YUYV 4:2:2) +``` + +Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed. + +If everything is set up correctly, you can proceed with the rest of the tutorial. + + + diff --git a/docs/source/contributing.md b/docs/source/contributing.md new file mode 120000 index 000000000..f939e75f2 --- /dev/null +++ b/docs/source/contributing.md @@ -0,0 +1 @@ +../../CONTRIBUTING.md \ No newline at end of file diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx new file mode 100644 index 000000000..149b25c68 --- /dev/null +++ b/docs/source/hilserl.mdx @@ -0,0 +1,547 @@ +# HIL-SERL Real Robot Training Workflow Guide + +In this tutorial you will go through the full Human-in-the-Loop Sample-Efficient Reinforcement Learning (HIL-SERL) workflow using LeRobot. You will master training a policy with RL on a real robot in just a few hours. + +HIL-SERL is a sample-efficient reinforcement learning algorithm that combines human demonstrations with online learning and human interventions. The approach starts from a small set of human demonstrations, uses them to train a reward classifier, and then employs an actor-learner architecture where humans can intervene during policy execution to guide exploration and correct unsafe behaviors. In this tutorial, you'll use a gamepad to provide interventions and control the robot during the learning process. + +It combines three key ingredients: + 1. **Offline demonstrations & reward classifier:** a handful of human-teleop episodes plus a vision-based success detector give the policy a shaped starting point. + 2. **On-robot actor / learner loop with human interventions:** a distributed Soft Actor Critic (SAC) learner updates the policy while an actor explores on the physical robot; the human can jump in at any time to correct dangerous or unproductive behaviour. + 3. **Safety & efficiency tools:** joint/end-effector (EE) bounds, crop region of interest (ROI) preprocessing and WandB monitoring keep the data useful and the hardware safe. + +Together these elements let HIL-SERL reach near-perfect task success and faster cycle times than imitation-only baselines. + +

+ HIL-SERL workflow +

+ +

HIL-SERL workflow, Luo et al. 2024

+ +This guide provides step-by-step instructions for training a robot policy using LeRobot's HilSerl implementation to train on a real robot. + +## What do I need? + +- A gamepad (recommended) or keyboard to control the robot +- A Nvidia GPU +- A real robot with a follower and leader arm (optional if you use the keyboard or the gamepad) + +## What kind of tasks can I train? + +One can use HIL-SERL to train on a variety of manipulation tasks. Some recommendations: +- Start with a simple task to understand how the system works. + - Push cube to a goal region + - Pick and lift cube with the gripper +- Avoid extremely long horizon tasks. Focus on tasks that can be completed in 5-10 seconds. +- Once you have a good idea of how the system works, you can try more complex tasks and longer horizons. + - Pick and place cube + - Bimanual tasks to pick objects with two arms + - Hand-over tasks to transfer objects from one arm to another + - Go crazy! + +## Install LeRobot with HIL-SERL + +To install LeRobot with HIL-SERL, you need to install the `hilserl` extra. + +```bash +pip install -e ".[hilserl]" +``` + +## Real Robot Training Workflow + +### Understanding Configuration + +The training process begins with proper configuration for the HILSerl environment. The configuration class of interest is `HILSerlRobotEnvConfig` in `lerobot/common/envs/configs.py`. Which is defined as: + +```python +class HILSerlRobotEnvConfig(EnvConfig): + robot: RobotConfig | None = None # Main robot agent (defined in `lerobot/common/robots`) + teleop: TeleoperatorConfig | None = None # Teleoperator agent, e.g., gamepad or leader arm, (defined in `lerobot/common/teleoperators`) + wrapper: EnvTransformConfig | None = None # Environment wrapper settings; check `lerobot/scripts/server/gym_manipulator.py` + fps: int = 10 # Control frequency + name: str = "real_robot" # Environment name + mode: str = None # "record", "replay", or None (for training) + repo_id: str | None = None # LeRobot dataset repository ID + dataset_root: str | None = None # Local dataset root (optional) + task: str = "" # Task identifier + num_episodes: int = 10 # Number of episodes for recording + episode: int = 0 # episode index for replay + device: str = "cuda" # Compute device + push_to_hub: bool = True # Whether to push the recorded datasets to Hub + pretrained_policy_name_or_path: str | None = None # For policy loading + reward_classifier_pretrained_path: str | None = None # For reward model + number_of_steps_after_success: int = 0 # For reward classifier, collect more positive examples after a success to train a classifier +``` + + +### Finding Robot Workspace Bounds + +Before collecting demonstrations, you need to determine the appropriate operational bounds for your robot. + +This helps simplify the problem of learning on the real robot in two ways: 1) by limiting the robot's operational space to a specific region that solves the task and avoids unnecessary or unsafe exploration, and 2) by allowing training in end-effector space rather than joint space. Empirically, learning in joint space for reinforcement learning in manipulation is often a harder problem - some tasks are nearly impossible to learn in joint space but become learnable when the action space is transformed to end-effector coordinates. + +**Using find_joint_limits.py** + +This script helps you find the safe operational bounds for your robot's end-effector. Given that you have a follower and leader arm, you can use the script to find the bounds for the follower arm that will be applied during training. +Bounding the action space will reduce the redundant exploration of the agent and guarantees safety. + +```bash +python -m lerobot.scripts.find_joint_limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue +``` + +**Workflow** + +1. Run the script and move the robot through the space that solves the task +2. The script will record the minimum and maximum end-effector positions and the joint angles and prints them to the console, for example: + ``` + Max ee position [0.2417 0.2012 0.1027] + Min ee position [0.1663 -0.0823 0.0336] + Max joint positions [-20.0, -20.0, -20.0, -20.0, -20.0, -20.0] + Min joint positions [50.0, 50.0, 50.0, 50.0, 50.0, 50.0] + ``` +3. Use these values in the configuration of your teleoperation device (TeleoperatorConfig) under the `end_effector_bounds` field + +**Example Configuration** + +```json +"end_effector_bounds": { + "max": [0.24, 0.20, 0.10], + "min": [0.16, -0.08, 0.03] +} +``` + +### Collecting Demonstrations + +With the bounds defined, you can safely collect demonstrations for training. Training RL with off-policy algorithm allows us to use offline datasets collected in order to improve the efficiency of the learning process. + +**Setting Up Record Mode** + +Create a configuration file for recording demonstrations (or edit an existing one like [env_config_so100.json](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json)): + +1. Set `mode` to `"record"` +2. Specify a unique `repo_id` for your dataset (e.g., "username/task_name") +3. Set `num_episodes` to the number of demonstrations you want to collect +4. Set `crop_params_dict` to `null` initially (we'll determine crops later) +5. Configure `robot`, `cameras`, and other hardware settings + +Example configuration section: +```json +"mode": "record", +"repo_id": "username/pick_lift_cube", +"dataset_root": null, +"task": "pick_and_lift", +"num_episodes": 15, +"episode": 0, +"push_to_hub": true +``` + +### Using a Teleoperation Device + +Along with your robot, you will need a teleoperation device to control it in order to collect datasets of your task and perform interventions during the online training. +We support using a gamepad or a keyboard or the leader arm of the robot. + +HIL-Serl learns actions in the end-effector space of the robot. Therefore, the teleoperation will control the end-effector's x,y,z displacements. + +For that we need to define a version of the robot that takes actions in the end-effector space. Check the robot class `SO100FollowerEndEffector` and its configuration `SO100FollowerEndEffectorConfig` for the default parameters related to the end-effector space. + +```python +class SO100FollowerEndEffectorConfig(SO100FollowerConfig): + """Configuration for the SO100FollowerEndEffector robot.""" + + # Default bounds for the end-effector position (in meters) + end_effector_bounds: dict[str, list[float]] = field( # bounds for the end-effector in x,y,z direction + default_factory=lambda: { + "min": [-1.0, -1.0, -1.0], # min x, y, z + "max": [1.0, 1.0, 1.0], # max x, y, z + } + ) + + max_gripper_pos: float = 50 # maximum gripper position that the gripper will be open at + + end_effector_step_sizes: dict[str, float] = field( # maximum step size for the end-effector in x,y,z direction + default_factory=lambda: { + "x": 0.02, + "y": 0.02, + "z": 0.02, + } + ) +``` + +The `Teleoperator` defines the teleoperation device. You can check the list of available teleoperators in `lerobot/common/teleoperators`. + +**Setting up the Gamepad** + +The gamepad provides a very convenient way to control the robot and the episode state. + +To setup the gamepad, you need to set the `control_mode` to `"gamepad"` and define the `teleop` section in the configuration file. + +```json + "teleop": { + "type": "gamepad", + "use_gripper": true + }, +``` + +

+ Figure shows the control mappings on a Logitech gamepad. +

+

Gamepad button mapping for robot control and episode management

+ +**Setting up the SO101 leader** + +The SO101 leader arm has reduced gears that allows it to move and track the follower arm during exploration. Therefore, taking over is much smoother than the gearless SO100. + +To setup the SO101 leader, you need to set the `control_mode` to `"leader"` and define the `teleop` section in the configuration file. + +```json + "teleop": { + "type": "so101_leader", + "port": "/dev/tty.usbmodem585A0077921", # check your port number + "use_degrees": true + }, +``` + +In order to annotate the success/failure of the episode, **you will need** to use a keyboard to press `s` for success, `esc` for failure. +During the online training, press `space` to take over the policy and `space` again to give the control back to the policy. + +
+Video: SO101 leader teleoperation + +
+ +
+ +

SO101 leader teleoperation example, the leader tracks the follower, press `space` to intervene

+
+ +**Recording Demonstrations** + +Start the recording process, an example of the config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_so100.json): + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config_so100.json +``` + +During recording: +1. The robot will reset to the initial position defined in the configuration file `fixed_reset_joint_positions` +2. Complete the task successfully +3. The episode ends with a reward of 1 when you press the "success" button +4. If the time limit is reached, or the fail button is pressed, the episode ends with a reward of 0 +5. You can rerecord an episode by pressing the "rerecord" button +6. The process automatically continues to the next episode +7. After recording all episodes, the dataset is pushed to the Hugging Face Hub (optional) and saved locally + + +### Processing the Dataset + +After collecting demonstrations, process them to determine optimal camera crops. +Reinforcement learning is sensitive to background distractions, so it is important to crop the images to the relevant workspace area. + +Visual RL algorithms learn directly from pixel inputs, making them vulnerable to irrelevant visual information. Background elements like changing lighting, shadows, people moving, or objects outside the workspace can confuse the learning process. Good ROI selection should: +- Include only the essential workspace where the task happens +- Capture the robot's end-effector and all objects involved in the task +- Exclude unnecessary background elements and distractions + +Note: If you already know the crop parameters, you can skip this step and just set the `crop_params_dict` in the configuration file during recording. + +**Determining Crop Parameters** + +Use the `crop_dataset_roi.py` script to interactively select regions of interest in your camera images: + +```bash +python lerobot/scripts/rl/crop_dataset_roi.py --repo-id username/pick_lift_cube +``` + +1. For each camera view, the script will display the first frame +2. Draw a rectangle around the relevant workspace area +3. Press 'c' to confirm the selection +4. Repeat for all camera views +5. The script outputs cropping parameters and creates a new cropped dataset + +Example output: +``` +Selected Rectangular Regions of Interest (top, left, height, width): +observation.images.side: [180, 207, 180, 200] +observation.images.front: [180, 250, 120, 150] +``` + +

+ +

+ +

Interactive cropping tool for selecting regions of interest

+ + +**Updating Configuration** + +Add these crop parameters to your training configuration: + +```json +"crop_params_dict": { + "observation.images.side": [180, 207, 180, 200], + "observation.images.front": [180, 250, 120, 150] +}, +"resize_size": [128, 128] +``` + +**Recommended image resolution** + +Most vision-based policies have been validated on square inputs of either **128×128** (default) or **64×64** pixels. We therefore advise setting the resize_size parameter to [128, 128] – or [64, 64] if you need to save GPU memory and bandwidth. Other resolutions are possible but have not been extensively tested. + + +### Training a Reward Classifier + +The reward classifier plays an important role in the HIL-SERL workflow by automating reward assignment and automatically detecting episode success. Instead of manually defining reward functions or relying on human feedback for every timestep, the reward classifier learns to predict success/failure from visual observations. This enables the RL algorithm to learn efficiently by providing consistent and automated reward signals based on the robot's camera inputs. + +This guide explains how to train a reward classifier for human-in-the-loop reinforcement learning implementation of LeRobot. Reward classifiers learn to predict the reward value given a state which can be used in an RL setup to train a policy. + +**Note**: Training a reward classifier is optional. You can start the first round of RL experiments by annotating the success manually with your gamepad or keyboard device. + +The reward classifier implementation in `modeling_classifier.py` uses a pretrained vision model to process the images. It can output either a single value for binary rewards to predict success/fail cases or multiple values for multi-class settings. + +**Collecting a Dataset for the reward classifier** + +Before training, you need to collect a dataset with labeled examples. The `record_dataset` function in `gym_manipulator.py` enables the process of collecting a dataset of observations, actions, and rewards. + +To collect a dataset, you need to modify some parameters in the environment configuration based on HILSerlRobotEnvConfig. + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/reward_classifier_train_config.json +``` + +**Key Parameters for Data Collection** + +- **mode**: set it to `"record"` to collect a dataset +- **repo_id**: `"hf_username/dataset_name"`, name of the dataset and repo on the hub +- **num_episodes**: Number of episodes to record +- **number_of_steps_after_success**: Number of additional frames to record after a success (reward=1) is detected +- **fps**: Number of frames per second to record +- **push_to_hub**: Whether to push the dataset to the hub + +The `number_of_steps_after_success` parameter is crucial as it allows you to collect more positive examples. When a success is detected, the system will continue recording for the specified number of steps while maintaining the reward=1 label. Otherwise, there won't be enough states in the dataset labeled to 1 to train a good classifier. + +Example configuration section for data collection: + +```json +{ + "mode": "record", + "repo_id": "hf_username/dataset_name", + "dataset_root": "data/your_dataset", + "num_episodes": 20, + "push_to_hub": true, + "fps": 10, + "number_of_steps_after_success": 15 +} +``` + +**Reward Classifier Configuration** + +The reward classifier is configured using `configuration_classifier.py`. Here are the key parameters: + +- **model_name**: Base model architecture (e.g., we mainly use `"helper2424/resnet10"`) +- **model_type**: `"cnn"` or `"transformer"` +- **num_cameras**: Number of camera inputs +- **num_classes**: Number of output classes (typically 2 for binary success/failure) +- **hidden_dim**: Size of hidden representation +- **dropout_rate**: Regularization parameter +- **learning_rate**: Learning rate for optimizer + +Example configuration for training the [reward classifier](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/reward_classifier_train_config.json): + +```json +{ + "policy": { + "type": "reward_classifier", + "model_name": "helper2424/resnet10", + "model_type": "cnn", + "num_cameras": 2, + "num_classes": 2, + "hidden_dim": 256, + "dropout_rate": 0.1, + "learning_rate": 1e-4, + "device": "cuda", + "use_amp": true, + "input_features": { + "observation.images.front": { + "type": "VISUAL", + "shape": [3, 128, 128] + }, + "observation.images.side": { + "type": "VISUAL", + "shape": [3, 128, 128] + } + } + } +} +``` + +**Training the Classifier** + +To train the classifier, use the `train.py` script with your configuration: + +```bash +python lerobot/scripts/train.py --config_path path/to/reward_classifier_train_config.json +``` + +**Deploying and Testing the Model** + +To use your trained reward classifier, configure the `HILSerlRobotEnvConfig` to use your model: + +```python +env_config = HILSerlRobotEnvConfig( + reward_classifier_pretrained_path="path_to_your_pretrained_trained_model", + # Other environment parameters +) +``` +or set the argument in the json config file. + +```json +{ + "reward_classifier_pretrained_path": "path_to_your_pretrained_model" +} +``` + +Run `gym_manipulator.py` to test the model. +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/env_config.json +``` + +The reward classifier will automatically provide rewards based on the visual input from the robot's cameras. + +**Example Workflow for training the reward classifier** + +1. **Create the configuration files**: + Create the necessary json configuration files for the reward classifier and the environment. Check the examples [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/tree/main). + +2. **Collect a dataset**: + ```bash + python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json + ``` + +3. **Train the classifier**: + ```bash + python lerobot/scripts/train.py --config_path lerobot/configs/reward_classifier_train_config.json + ``` + +4. **Test the classifier**: + ```bash + python lerobot/scripts/rl/gym_manipulator.py --config_path lerobot/configs/env_config.json + ``` + +### Training with Actor-Learner + +The LeRobot system uses a distributed actor-learner architecture for training. This architecture decouples robot interactions from the learning process, allowing them to run concurrently without blocking each other. The actor server handles robot observations and actions, sending interaction data to the learner server. The learner server performs gradient descent and periodically updates the actor's policy weights. You will need to start two processes: a learner and an actor. + +**Configuration Setup** + +Create a training configuration file (example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_config_hilserl_so100.json)). The training config is based on the main `TrainRLServerPipelineConfig` class in `lerobot/configs/train.py`. + +1. Configure the policy settings (`type="sac"`, `device`, etc.) +2. Set `dataset` to your cropped dataset +3. Configure environment settings with crop parameters +4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/common/policies/sac/configuration_sac.py#L79). +5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. + +**Starting the Learner** + +First, start the learner server process: + +```bash +python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +The learner: +- Initializes the policy network +- Prepares replay buffers +- Opens a `gRPC` server to communicate with actors +- Processes transitions and updates the policy + +**Starting the Actor** + +In a separate terminal, start the actor process with the same configuration: + +```bash +python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +The actor: +- Connects to the learner via `gRPC` +- Initializes the environment +- Execute rollouts of the policy to collect experience +- Sends transitions to the learner +- Receives updated policy parameters + +**Training Flow** + +The training proceeds automatically: + +1. The actor executes the policy in the environment +2. Transitions are collected and sent to the learner +3. The learner updates the policy based on these transitions +4. Updated policy parameters are sent back to the actor +5. The process continues until the specified step limit is reached + +**Human in the Loop** + +- The key to learning efficiently is to have human interventions to provide corrective feedback and completing the task to aide the policy learning and exploration. +- To perform human interventions, you can press the upper right trigger button on the gamepad (or the `space` key on the keyboard). This will pause the policy actions and allow you to take over. +- A successful experiment is one where the human has to intervene at the start but then reduces the amount of interventions as the policy improves. You can monitor the intervention rate in the `wandb` dashboard. + +

+ Figure shows the control mappings on a Logitech gamepad. +

+ +

Example showing how human interventions help guide policy learning over time

+ +- The figure shows the plot of the episodic reward over interaction step. The figure shows the effect of human interventions on the policy learning. +- The orange curve is an experiment without any human interventions. While the pink and blue curves are experiments with human interventions. +- We can observe that the number of steps where the policy starts achieving the maximum reward is cut by a quarter when human interventions are present. + +**Monitoring and Debugging** + +If you have `wandb.enable` set to `true` in your configuration, you can monitor training progress in real-time through the [Weights & Biases](https://wandb.ai/site/) dashboard. + +### Guide to Human Interventions +The learning process is very sensitive to the intervention strategy. It will takes a few runs to understand how to intervene effectively. Some tips and hints: +- Allow the policy to explore for a few episodes at the start of training. +- Avoid intervening for long periods of time. Try to intervene in situation to correct the robot's behaviour when it goes off track. +- Once the policy starts achieving the task, even if its not perfect, you can limit your interventions to simple quick actions like a simple grasping commands. + +The ideal behaviour is that your intervention rate should drop gradually during training as shown in the figure below. + +

+ Intervention rate +

+ +

Plot of the intervention rate during a training run on a pick and lift cube task

+ +### Key hyperparameters to tune + +Some configuration values have a disproportionate impact on training stability and speed: + +- **`temperature_init`** (`policy.temperature_init`) – initial entropy temperature in SAC. Higher values encourage more exploration; lower values make the policy more deterministic early on. A good starting point is `1e-2`. We observed that setting it too high can make human interventions ineffective and slow down learning. +- **`policy_parameters_push_frequency`** (`policy.actor_learner_config.policy_parameters_push_frequency`) – interval in *seconds* between two weight pushes from the learner to the actor. The default is `4 s`. Decrease to **1-2 s** to provide fresher weights (at the cost of more network traffic); increase only if your connection is slow, as this will reduce sample efficiency. +- **`storage_device`** (`policy.storage_device`) – device on which the learner keeps the policy parameters. If you have spare GPU memory, set this to `"cuda"` (instead of the default `"cpu"`). Keeping the weights on-GPU removes CPU→GPU transfer overhead and can significantly increase the number of learner updates per second. + + +Congrats 🎉, you have finished this tutorial! + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). + +Paper citation: +``` +@article{luo2024precise, + title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, + author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey}, + journal={arXiv preprint arXiv:2410.21845}, + year={2024} +} +``` diff --git a/docs/source/hilserl_sim.mdx b/docs/source/hilserl_sim.mdx new file mode 100644 index 000000000..3239ba91a --- /dev/null +++ b/docs/source/hilserl_sim.mdx @@ -0,0 +1,120 @@ +# Train RL in Simulation + +This guide explains how to use the `gym_hil` simulation environments as an alternative to real robots when working with the LeRobot framework for Human-In-the-Loop (HIL) reinforcement learning. + +`gym_hil` is a package that provides Gymnasium-compatible simulation environments specifically designed for Human-In-the-Loop reinforcement learning. These environments allow you to: + +- Train policies in simulation to test the RL stack before training on real robots + +- Collect demonstrations in sim using external devices like gamepads or keyboards +- Perform human interventions during policy learning + +Currently, the main environment is a Franka Panda robot simulation based on MuJoCo, with tasks like picking up a cube. + + +## Installation + +First, install the `gym_hil` package within the LeRobot environment: + +```bash +pip install -e ".[hilserl]" +``` + +## What do I need? + +- A gamepad or keyboard to control the robot +- A Nvidia GPU + + + +## Configuration + +To use `gym_hil` with LeRobot, you need to create a configuration file. An example is provided [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/gym_hil_env.json). Key configuration sections include: + +### Environment Type and Task + +```json +{ + "type": "hil", + "name": "franka_sim", + "task": "PandaPickCubeGamepad-v0", + "device": "cuda" +} +``` + +Available tasks: +- `PandaPickCubeBase-v0`: Basic environment +- `PandaPickCubeGamepad-v0`: With gamepad control +- `PandaPickCubeKeyboard-v0`: With keyboard control + +### Gym Wrappers Configuration + +```json +"wrapper": { + "gripper_penalty": -0.02, + "control_time_s": 15.0, + "use_gripper": true, + "fixed_reset_joint_positions": [0.0, 0.195, 0.0, -2.43, 0.0, 2.62, 0.785], + "end_effector_step_sizes": { + "x": 0.025, + "y": 0.025, + "z": 0.025 + }, + "control_mode": "gamepad" + } +``` + +Important parameters: +- `gripper_penalty`: Penalty for excessive gripper movement +- `use_gripper`: Whether to enable gripper control +- `end_effector_step_sizes`: Size of the steps in the x,y,z axes of the end-effector +- `control_mode`: Set to `"gamepad"` to use a gamepad controller + +## Running with HIL RL of LeRobot + +### Basic Usage + +To run the environment, set mode to null: + +```python +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json +``` + +### Recording a Dataset + +To collect a dataset, set the mode to `record` whilst defining the repo_id and number of episodes to record: + +```python +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/gym_hil_env.json +``` + +### Training a Policy + +To train a policy, checkout the configuration example available [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/train_gym_hil_env.json) and run the actor and learner servers: + +```python +python lerobot/scripts/rl/actor.py --config_path path/to/train_gym_hil_env.json +``` + +In a different terminal, run the learner server: + +```python +python lerobot/scripts/rl/learner.py --config_path path/to/train_gym_hil_env.json +``` + +The simulation environment provides a safe and repeatable way to develop and test your Human-In-the-Loop reinforcement learning components before deploying to real robots. + +Congrats 🎉, you have finished this tutorial! + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). + +Paper citation: +``` +@article{luo2024precise, + title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, + author={Luo, Jianlan and Xu, Charles and Wu, Jeffrey and Levine, Sergey}, + journal={arXiv preprint arXiv:2410.21845}, + year={2024} +} +``` diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx new file mode 100644 index 000000000..3dd9d80fb --- /dev/null +++ b/docs/source/il_robots.mdx @@ -0,0 +1,541 @@ +# Imitation Learning on Real-World Robots + +This tutorial will explain how to train a neural network to control a real robot autonomously. + +**You'll learn:** +1. How to record and visualize your dataset. +2. How to train a policy using your data and prepare it for evaluation. +3. How to evaluate your policy and visualize the results. + +By following these steps, you'll be able to replicate tasks, such as picking up a Lego block and placing it in a bin with a high success rate, as shown in the video below. + +
+Video: pickup lego block task + +
+ +
+ +
+ +This tutorial isn’t tied to a specific robot: we walk you through the commands and API snippets you can adapt for any supported platform. + +During data collection, you’ll use a “teloperation” device, such as a leader arm or keyboard to teleoperate the robot and record its motion trajectories. + +Once you’ve gathered enough trajectories, you’ll train a neural network to imitate these trajectories and deploy the trained model so your robot can perform the task autonomously. + +If you run into any issues at any point, jump into our [Discord community](https://discord.com/invite/s3KuuzsPFb) for support. + +## Set up and Calibrate + +If you haven't yet set up and calibrated your robot and teleop device, please do so by following the robot-specific tutorial. + +## Teleoperate + +In this example, we’ll demonstrate how to teleoperate the SO101 robot. For each command, we also provide a corresponding API example. + +Note that the `id` associated with a robot is used to store the calibration file. It's important to use the same `id` when teleoperating, recording, and evaluating when using the same setup. + + + +```bash +python -m lerobot.teleoperate \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=my_awesome_follower_arm \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=my_awesome_leader_arm +``` + + +```python +from lerobot.common.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader +from lerobot.common.robots.so101_follower import SO101FollowerConfig, SO101Follower + +robot_config = SO101FollowerConfig( + port="/dev/tty.usbmodem58760431541", + id="my_red_robot_arm", +) + +teleop_config = SO101LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_blue_leader_arm", +) + +robot = SO101Follower(robot_config) +teleop_device = SO101Leader(teleop_config) +robot.connect() +teleop_device.connect() + +while True: + action = teleop_device.get_action() + robot.send_action(action) +``` + + + +The teleoperate command will automatically: +1. Identify any missing calibrations and initiate the calibration procedure. +2. Connect the robot and teleop device and start teleoperation. + +## Cameras + +To add cameras to your setup, follow this [Guide](./cameras#setup-cameras). + +## Teleoperate with cameras + +With `rerun`, you can teleoperate again while simultaneously visualizing the camera feeds and joint positions. In this example, we’re using the Koch arm. + + + +```bash +python -m lerobot.teleoperate \ + --robot.type=koch_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=my_awesome_follower_arm \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ + --teleop.type=koch_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=my_awesome_leader_arm \ + --display_data=true +``` + + +```python +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.common.teleoperators.koch_leader import KochLeaderConfig, KochLeader +from lerobot.common.robots.koch_follower import KochFollowerConfig, KochFollower + +camera_config = { + "front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30) +} + +robot_config = KochFollowerConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_red_robot_arm", + cameras=camera_config +) + +teleop_config = KochLeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_blue_leader_arm", +) + +robot = KochFollower(robot_config) +teleop_device = KochLeader(teleop_config) +robot.connect() +teleop_device.connect() + +while True: + observation = robot.get_observation() + action = teleop_device.get_action() + robot.send_action(action) +``` + + + +## Record a dataset + +Once you're familiar with teleoperation, you can record your first dataset. + +We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens). + +Add your token to the CLI by running this command: +```bash +huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +``` + +Then store your Hugging Face repository name in a variable: +```bash +HF_USER=$(huggingface-cli whoami | head -n 1) +echo $HF_USER +``` + +Now you can record a dataset. To record 5 episodes and upload your dataset to the hub, adapt the code below for your robot and execute the command or API example. + + + +```bash +python -m lerobot.record \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem585A0076841 \ + --robot.id=my_awesome_follower_arm \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=my_awesome_leader_arm \ + --display_data=true \ + --dataset.repo_id=${HF_USER}/record-test \ + --dataset.num_episodes=5 \ + --dataset.single_task="Grab the black cube" +``` + + +```python +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import hw_to_dataset_features +from lerobot.common.robots.so100_follower import SO100Follower, SO100FollowerConfig +from lerobot.common.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig +from lerobot.common.teleoperators.so100_leader.so100_leader import SO100Leader +from lerobot.common.utils.control_utils import init_keyboard_listener +from lerobot.common.utils.utils import log_say +from lerobot.common.utils.visualization_utils import _init_rerun +from lerobot.record import record_loop + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +RESET_TIME_SEC = 10 +TASK_DESCRIPTION = "My task description" + +# Create the robot and teleoperator configurations +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config +) +teleop_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm") + +# Initialize the robot and teleoperator +robot = SO100Follower(robot_config) +teleop = SO100Leader(teleop_config) + +# Configure the dataset features +action_features = hw_to_dataset_features(robot.action_features, "action") +obs_features = hw_to_dataset_features(robot.observation_features, "observation") +dataset_features = {**action_features, **obs_features} + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id="/", + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Initialize the keyboard listener and rerun visualization +_, events = init_keyboard_listener() +_init_rerun(session_name="recording") + +# Connect the robot and teleoperator +robot.connect() +teleop.connect() + +episode_idx = 0 +while episode_idx < NUM_EPISODES and not events["stop_recording"]: + log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=teleop, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + ) + + # Reset the environment if not stopping or re-recording + if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): + log_say("Reset the environment") + record_loop( + robot=robot, + events=events, + fps=FPS, + teleop=teleop, + control_time_s=RESET_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + ) + + if events["rerecord_episode"]: + log_say("Re-recording episode") + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + episode_idx += 1 + +# Clean up +log_say("Stop recording") +robot.disconnect() +teleop.disconnect() +dataset.push_to_hub() +``` + + + +#### Dataset upload +Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: +```bash +echo https://huggingface.co/datasets/${HF_USER}/so101_test +``` +Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example). + +You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot). + +#### Record function + +The `record` function provides a suite of tools for capturing and managing data during robot operation: + +##### 1. Data Storage +- Data is stored using the `LeRobotDataset` format and is stored on disk during recording. +- By default, the dataset is pushed to your Hugging Face page after recording. + - To disable uploading, use `--dataset.push_to_hub=False`. + +##### 2. Checkpointing and Resuming +- Checkpoints are automatically created during recording. +- If an issue occurs, you can resume by re-running the same command with `--resume=true`. +- To start recording from scratch, **manually delete** the dataset directory. + +##### 3. Recording Parameters +Set the flow of data recording using command-line arguments: +- `--dataset.episode_time_s=60` + Duration of each data recording episode (default: **60 seconds**). +- `--dataset.reset_time_s=60` + Duration for resetting the environment after each episode (default: **60 seconds**). +- `--dataset.num_episodes=50` + Total number of episodes to record (default: **50**). + +##### 4. Keyboard Controls During Recording +Control the data recording flow using keyboard shortcuts: +- Press **Right Arrow (`→`)**: Early stop the current episode or reset time and move to the next. +- Press **Left Arrow (`←`)**: Cancel the current episode and re-record it. +- Press **Escape (`ESC`)**: Immediately stop the session, encode videos, and upload the dataset. + +#### Tips for gathering data + +Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images. + +In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions. + +Avoid adding too much variation too quickly, as it may hinder your results. + +If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset. + + +#### Troubleshooting: +- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). + +## Visualize a dataset + +If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: +```bash +echo ${HF_USER}/so101_test +``` + +## Replay an episode + +A useful feature is the `replay` function, which allows you to replay any episode that you've recorded or episodes from any dataset out there. This function helps you test the repeatability of your robot's actions and assess transferability across robots of the same model. + +You can replay the first episode on your robot with either the command below or with the API example: + + + +```bash +python -m lerobot.replay \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=my_awesome_follower_arm \ + --dataset.repo_id=${HF_USER}/record-test \ + --dataset.episode=0 # choose the episode you want to replay +``` + + +```python +import time + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.common.robots.so100_follower.so100_follower import SO100Follower +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import log_say + +episode_idx = 0 + +robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm") + +robot = SO100Follower(robot_config) +robot.connect() + +dataset = LeRobotDataset("/", episodes=[episode_idx]) +actions = dataset.hf_dataset.select_columns("action") + +log_say(f"Replaying episode {episode_idx}") +for idx in range(dataset.num_frames): + t0 = time.perf_counter() + + action = { + name: float(actions[idx]["action"][i]) for i, name in enumerate(dataset.features["action"]["names"]) + } + robot.send_action(action) + + busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0)) + +robot.disconnect() +``` + + + +Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com). + +## Train a policy + +To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +```bash +python lerobot/scripts/train.py \ + --dataset.repo_id=${HF_USER}/so101_test \ + --policy.type=act \ + --output_dir=outputs/train/act_so101_test \ + --job_name=act_so101_test \ + --policy.device=cuda \ + --wandb.enable=true \ + --policy.repo_id=${HF_USER}/my_policy +``` + +Let's explain the command: +1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. + +Training should take several hours. You will find checkpoints in `outputs/train/act_so101_test/checkpoints`. + +To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy: +```bash +python lerobot/scripts/train.py \ + --config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \ + --resume=true +``` + +If you do not want to push your model to the hub after training use `--policy.push_to_hub=false`. + +Additionally you can provide extra `tags` or specify a `license` for your model or make the model repo `private` by adding this: `--policy.private=true --policy.tags=\[ppo,rl\] --policy.license=mit` + +#### Train using Collab +If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). + +#### Upload policy checkpoints + +Once training is done, upload the latest checkpoint with: +```bash +huggingface-cli upload ${HF_USER}/act_so101_test \ + outputs/train/act_so101_test/checkpoints/last/pretrained_model +``` + +You can also upload intermediate checkpoints with: +```bash +CKPT=010000 +huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \ + outputs/train/act_so101_test/checkpoints/${CKPT}/pretrained_model +``` + +## Run inference and evaluate your policy + +You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes: + + + +```bash +python -m lerobot.record \ + --robot.type=so100_follower \ + --robot.port=/dev/ttyACM1 \ + --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \ + --robot.id=my_awesome_follower_arm \ + --display_data=false \ + --dataset.repo_id=${HF_USER}/eval_so100 \ + --dataset.single_task="Put lego brick into the transparent box" \ + # <- Teleop optional if you want to teleoperate in between episodes \ + # --teleop.type=so100_leader \ + # --teleop.port=/dev/ttyACM0 \ + # --teleop.id=my_awesome_leader_arm \ + --policy.path=${HF_USER}/my_policy +``` + + +```python +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import hw_to_dataset_features +from lerobot.common.policies.act.modeling_act import ACTPolicy +from lerobot.common.robots.so100_follower.config_so100_follower import SO100FollowerConfig +from lerobot.common.robots.so100_follower.so100_follower import SO100Follower +from lerobot.common.utils.control_utils import init_keyboard_listener +from lerobot.common.utils.utils import log_say +from lerobot.common.utils.visualization_utils import _init_rerun +from lerobot.record import record_loop + +NUM_EPISODES = 5 +FPS = 30 +EPISODE_TIME_SEC = 60 +TASK_DESCRIPTION = "My task description" + +# Create the robot configuration +camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} +robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm", cameras=camera_config +) + +# Initialize the robot +robot = SO100Follower(robot_config) + +# Initialize the policy +policy = ACTPolicy.from_pretrained("/") + +# Configure the dataset features +action_features = hw_to_dataset_features(robot.action_features, "action") +obs_features = hw_to_dataset_features(robot.observation_features, "observation") +dataset_features = {**action_features, **obs_features} + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id="/eval_", + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Initialize the keyboard listener and rerun visualization +_, events = init_keyboard_listener() +_init_rerun(session_name="recording") + +# Connect the robot +robot.connect() + +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + # Run the policy inference loop + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=True, + ) + + dataset.save_episode() + +# Clean up +robot.disconnect() +dataset.push_to_hub() +``` + + + +As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: +1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so101_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so101_test`). +2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so101_test`). diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx new file mode 100644 index 000000000..625b2fc00 --- /dev/null +++ b/docs/source/il_sim.mdx @@ -0,0 +1,152 @@ +# Imitation Learning in Sim + +This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning. + +**You'll learn:** +1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset. +2. How to train a policy using your data. +3. How to evaluate your policy in simulation and visualize the results. + +For the simulation environment we use the same [repo](https://github.com/huggingface/gym-hil) that is also being used by the Human-In-the-Loop (HIL) reinforcement learning algorithm. +This environment is based on [MuJoCo](https://mujoco.org) and allows you to record datasets in LeRobotDataset format. +Teleoperation is easiest with a controller like the Logitech F710, but you can also use your keyboard if you are up for the challenge. + +## Installation + +First, install the `gym_hil` package within the LeRobot environment, go to your LeRobot folder and run this command: + +```bash +pip install -e ".[hilserl]" +``` + +## Teleoperate and Record a Dataset + +To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/env_config_gym_hil_il.json). + +To teleoperate and collect a dataset, we need to modify this config file and you should add your `repo_id` here: `"repo_id": "il_gym",` and `"num_episodes": 30,` and make sure you set `mode` to `record`, "mode": "record". + +If you do not have a Nvidia GPU also change `"device": "cuda"` parameter in the config file (for example to `mps` for MacOS). + +By default the config file assumes you use a controller. To use your keyboard please change the envoirment specified at `"task"` in the config file and set it to `"PandaPickCubeKeyboard-v0"`. + +Then we can run this command to start: + + + + +```bash +python lerobot/scripts/rl/gym_manipulator.py --config_path path/to/env_config_gym_hil_il.json +``` + + + + +```bash +mjpython lerobot/scripts/rl/gym_manipulator.py --config_path path/to/env_config_gym_hil_il.json +``` + + + + +Once rendered you can teleoperate the robot with the gamepad or keyboard, below you can find the gamepad/keyboard controls. + +Note that to teleoperate the robot you have to hold the "Human Take Over Pause Policy" Button `RB` to enable control! + +**Gamepad Controls** + +

+ Figure shows the control mappings on a Logitech gamepad. +

+

Gamepad button mapping for robot control and episode management

+ +**Keyboard controls** + +For keyboard controls use the `spacebar` to enable control and the following keys to move the robot: +```bash + Arrow keys: Move in X-Y plane + Shift and Shift_R: Move in Z axis + Right Ctrl and Left Ctrl: Open and close gripper + ESC: Exit +``` + +## Visualize a dataset + +If you uploaded your dataset to the hub you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id. + +

+ Figure shows the dataset visualizer +

+

Dataset visualizer

+ + +## Train a policy + +To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +```bash +python lerobot/scripts/train.py \ + --dataset.repo_id=${HF_USER}/il_gym \ + --policy.type=act \ + --output_dir=outputs/train/il_sim_test \ + --job_name=il_sim_test \ + --policy.device=cuda \ + --wandb.enable=true +``` + +Let's explain the command: +1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. + +Training should take several hours, 100k steps (which is the default) will take about 1h on Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`. + +#### Train using Collab +If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). + +#### Upload policy checkpoints + +Once training is done, upload the latest checkpoint with: +```bash +huggingface-cli upload ${HF_USER}/il_sim_test \ + outputs/train/il_sim_test/checkpoints/last/pretrained_model +``` + +You can also upload intermediate checkpoints with: +```bash +CKPT=010000 +huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \ + outputs/train/il_sim_test/checkpoints/${CKPT}/pretrained_model +``` + +## Evaluate your policy in Sim + +To evaluate your policy we have to use the config file that can be found [here](https://huggingface.co/datasets/aractingi/lerobot-example-config-files/blob/main/eval_config_gym_hil.json). + +Make sure to replace the `repo_id` with the dataset you trained on, for example `pepijn223/il_sim_dataset` and replace the `pretrained_policy_name_or_path` with your model id, for example `pepijn223/il_sim_model` + +Then you can run this command to visualize your trained policy + + + + +```bash +python lerobot/scripts/rl/eval_policy.py --config_path=path/to/eval_config_gym_hil.json +``` + + + + +```bash +mjpython lerobot/scripts/rl/eval_policy.py --config_path=path/to/eval_config_gym_hil.json +``` + + + + +> [!WARNING] +> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation. + +Congrats 🎉, you have finished this tutorial. If you want to continue with using LeRobot in simulation follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/docs/source/index.mdx b/docs/source/index.mdx new file mode 100644 index 000000000..b8ff56ea7 --- /dev/null +++ b/docs/source/index.mdx @@ -0,0 +1,19 @@ +
+ + HuggingFace Expert Acceleration Program + +
+ +# LeRobot + +**State-of-the-art machine learning for real-world robotics** + +🤗 LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier for entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models. + +🤗 LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning. + +🤗 LeRobot already provides a set of pretrained models, datasets with human collected demonstrations, and simulated environments so that everyone can get started. + +🤗 LeRobot hosts pretrained models and datasets on the LeRobot HuggingFace page. + +Join the LeRobot community on [Discord](https://discord.gg/s3KuuzsPFb) diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx new file mode 100644 index 000000000..51474d8f7 --- /dev/null +++ b/docs/source/installation.mdx @@ -0,0 +1,72 @@ +# Installation + +## Install LeRobot + +Currently only available from source. + +Download our source code: +```bash +git clone https://github.com/huggingface/lerobot.git +cd lerobot +``` + +Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) +```bash +conda create -y -n lerobot python=3.10 +``` + +Then activate your conda environment, you have to do this each time you open a shell to use lerobot: +```bash +conda activate lerobot +``` + +When using `miniconda`, install `ffmpeg` in your environment: +```bash +conda install ffmpeg -c conda-forge +``` + +> [!TIP] +> This usually installs `ffmpeg 7.X` for your platform compiled with the `libsvtav1` encoder. If `libsvtav1` is not supported (check supported encoders with `ffmpeg -encoders`), you can: +> - _[On any platform]_ Explicitly install `ffmpeg 7.X` using: +> ```bash +> conda install ffmpeg=7.1.1 -c conda-forge +> ``` +> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. + +Install 🤗 LeRobot: +```bash +pip install -e . +``` + +### Troubleshooting +If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. +To install these for linux run: +```bash +sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config +``` +For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) + +## Optional dependencies + +LeRobot provides optional extras for specific functionalities. Multiple extras can be combined (e.g., `.[aloha,feetech]`). For all available extras, refer to `pyproject.toml`. + +### Simulations +Install environment packages: `aloha` ([gym-aloha](https://github.com/huggingface/gym-aloha)), `xarm` ([gym-xarm](https://github.com/huggingface/gym-xarm)), or `pusht` ([gym-pusht](https://github.com/huggingface/gym-pusht)) +Example: +```bash +pip install -e ".[aloha]" # or "[pusht]" for example +``` + +### Motor Control +For Koch v1.1 install the Dynamixel SDK, for SO100/SO101/Moss install the Feetech SDK. +```bash +pip install -e ".[feetech]" # or "[dynamixel]" for example +``` + +### Experiment Tracking +To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with +```bash +wandb login +``` + +You can now assemble your robot if it's not ready yet, look for your robot type on the left. Then follow the link below to use Lerobot with your robot. diff --git a/docs/source/integrate_hardware.mdx b/docs/source/integrate_hardware.mdx new file mode 100644 index 000000000..f7de1cece --- /dev/null +++ b/docs/source/integrate_hardware.mdx @@ -0,0 +1,318 @@ +# Bring Your Own Hardware + +This tutorial will explain how to integrate your own robot design into the LeRobot ecosystem and have it access all of our tools (data collection, control pipelines, policy training and inference). + +To that end, we provide the [`Robot`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robots/robot.py) base class in the LeRobot which specifies a standard interface for physical robot integration. Let's see how to implement it. + +## Prerequisites + +- Your own robot which exposes a communication interface (e.g. serial, CAN, TCP) +- A way to read sensor data and send motor commands programmatically, e.g. manufacturer's SDK or API, or your own protocol implementation. +- LeRobot installed in your environment. Follow our [Installation Guide](./installation). + +## Choose your motors + +If you're using Feetech or Dynamixel motors, LeRobot provides built-in bus interfaces: + +- [`FeetechMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/feetech/feetech.py) – for controlling Feetech servos +- [`DynamixelMotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/dynamixel/dynamixel.py) – for controlling Dynamixel servos + +Please refer to the [`MotorsBus`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/motors_bus.py) abstract class to learn about its API. +For a good example of how it can be used, you can have a look at our own [SO101 follower implementation](https://github.com/huggingface/lerobot/blob/main/lerobot/common/robots/so101_follower/so101_follower.py) + +Use these if compatible. Otherwise, you'll need to find or write a Python interface (not covered in this tutorial): +- Find an existing SDK in Python (or use bindings to C/C++) +- Or implement a basic communication wrapper (e.g., via pyserial, socket, or CANopen) + +You're not alone—many community contributions use custom boards or firmware! + +For Feetech and Dynamixel, we currently support these servos: + - Feetech: + - STS & SMS series (protocol 0): `sts3215`, `sts3250`, `sm8512bl` + - SCS series (protocol 1): `scs0009` + - Dynamixel (protocol 2.0 only): `xl330-m077`, `xl330-m288`, `xl430-w250`, `xm430-w350`, `xm540-w270`, `xc430-w150` + +If you are using Feetech or Dynamixel servos that are not in this list, you can add those in the [Feetech table](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/feetech/tables.py) or [Dynamixel table](https://github.com/huggingface/lerobot/blob/main/lerobot/common/motors/dynamixel/tables.py). Depending on the model, this will require you to add model-specific information. In most cases though, there shouldn't be a lot of additions to do. + +In the next sections, we'll use a `FeetechMotorsBus` as the motors interface for the examples. Replace it and adapt to your motors if necessary. + +## Step 1: Subclass the `Robot` Interface + +You’ll first need to specify the config class and a string identifier (`name`) for your robot. If your robot has special needs that you'd like to be able to change easily, it should go here (e.g. port/address, baudrate). + +Here, we'll add the port name and one camera by default for our robot: +```python +from dataclasses import dataclass, field + +from lerobot.common.cameras import CameraConfig +from lerobot.common.cameras.opencv import OpenCVCameraConfig +from lerobot.common.robots import RobotConfig + + +@RobotConfig.register_subclass("my_cool_robot") +@dataclass +class MyCoolRobotConfig(RobotConfig): + port: str + cameras: dict[str, CameraConfig] = field( + default_factory={ + "cam_1": OpenCVCameraConfig( + index_or_path=2, + fps=30, + width=480, + height=640, + ), + } + ) +``` + +Have a look at our [Cameras tutorial](./cameras) to understand how to detect and add your camera. + +Next, we'll create our actual robot class which inherits from `Robot`. This abstract class defines a contract you must follow for your robot to be usable with the rest of the LeRobot tools. + +Here we'll create a simple 5-DoF robot with one camera. It could be a simple arm but notice that the `Robot` abstract class does not assume anything on your robot's form factor. You can let you imagination run wild when designing new robots! + +```python +from lerobot.common.cameras import make_cameras_from_configs +from lerobot.common.motors import Motor, MotorNormMode +from lerobot.common.motors.feetech import FeetechMotorsBus +from lerobot.common.robots import Robot + +class MyCoolRobot(Robot): + config_class = MyCoolRobotConfig + name = "my_cool_robot" + + def __init__(self, config: MyCoolRobotConfig): + super().__init__(config) + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "joint_1": Motor(1, "sts3250", MotorNormMode.RANGE_M100_100), + "joint_2": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), + "joint_3": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), + "joint_4": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100), + "joint_5": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), + }, + calibration=self.calibration, + ) + self.cameras = make_cameras_from_configs(config.cameras) +``` + +## Step 2: Define Observation and Action Features + +These two properties define the *interface contract* between your robot and tools that consume it (such as data collection or learning pipelines). + +> [!WARNING] +> Note that these properties must be callable even if the robot is not yet connected, so avoid relying on runtime hardware state to define them. + +### `observation_features` + +This property should return a dictionary describing the structure of sensor outputs from your robot. The keys match what `get_observation()` returns, and the values describe either the shape (for arrays/images) or the type (for simple values). + +Example for our 5-DoF arm with one camera: +```python +@property +def _motors_ft(self) -> dict[str, type]: + return { + "joint_1.pos": float, + "joint_2.pos": float, + "joint_3.pos": float, + "joint_4.pos": float, + "joint_5.pos": float, + } + +@property +def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.cameras[cam].height, self.cameras[cam].width, 3) for cam in self.cameras + } + +@property +def observation_features(self) -> dict: + return {**self._motors_ft, **self._cameras_ft} +``` +In this case, observations consist of a simple dict storing each motor's position and a camera image. + +### `action_features` + +This property describes the commands your robot expects via `send_action()`. Again, keys must match the expected input format, and values define the shape/type of each command. + +Here, we simply use the same joints proprioceptive features (`self._motors_ft`) as with `observation_features`: the action sent will simply the goal position for each motor. +```python +def action_features(self) -> dict: + return self._motors_ft +``` + +## Step 3: Handle Connection and Disconnection + +These methods should handle opening and closing communication with your hardware (e.g. serial ports, CAN interfaces, USB devices, cameras). + +### `is_connected` + +This property should simply reflect that communication with the robot's hardware is established. When this property is `True`, it should be possible to read and write to the hardware using `get_observation()` and `send_action()`. + +```python +@property +def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) +``` + +### `connect()` + +This method should establish communication with the hardware. Moreover, if your robot needs calibration and is not calibrated, it should start a calibration procedure by default. If your robot needs some specific configuration, this should also be called here. + +```python +def connect(self, calibrate: bool = True) -> None: + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() +``` + +### `disconnect()` + +This method should gracefully terminate communication with the hardware: free any related resources (threads or processes), close ports, etc. + +Here, we already handle this in our `MotorsBus` and `Camera` classes so we just need to call their own `disconnect()` methods: +```python +def disconnect(self) -> None: + self.bus.disconnect() + for cam in self.cameras.values(): + cam.disconnect() +``` + +## Step 4: Support Calibration and Configuration + +LeRobot supports saving and loading calibration data automatically. This is useful for joint offsets, zero positions, or sensor alignment. + +> Note that depending on your hardware, this may not apply. If that's the case, you can simply leave these methods as no-ops: +> ```python +> @property +> def is_calibrated(self) -> bool: +> return True +> +> def calibrate(self) -> None: +> pass +> ``` + +### `is_calibrated` + +This should reflect whether your robot has the required calibration loaded. + +```python +@property +def is_calibrated(self) -> bool: + return self.bus.is_calibrated +``` + +### `calibrate()` + +The goal of the calibration is twofold: + - Know the physical range of motion of each motors in order to only send commands within this range. + - Normalize raw motors positions to sensible continuous values (e.g. percentages, degrees) instead of arbitrary discrete value dependant on the specific motor used that will not replicate elsewhere. + +It should implement the logic for calibration (if relevant) and update the `self.calibration` dictionary. If you are using Feetech or Dynamixel motors, our bus interfaces already include methods to help with this. + +```python +def calibrate(self) -> None: + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + print( + "Move all joints sequentially through their entire ranges " + "of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion() + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) +``` + +### `configure()` + +Use this to set up any configuration for your hardware (servos control modes, controller gains, etc.). This should usually be run at connection time and be idempotent. + +```python +def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + self.bus.write("P_Coefficient", motor, 16) + self.bus.write("I_Coefficient", motor, 0) + self.bus.write("D_Coefficient", motor, 32) +``` + +## Step 5: Implement Sensors Reading and Action Sending + +These are the most important runtime functions: the core I/O loop. + +### `get_observation()` + +Returns a dictionary of sensor values from the robot. These typically include motor states, camera frames, various sensors, etc. In the LeRobot framework, these observations are what will be fed to a policy in order to predict the actions to take. The dictionary keys and structure must match `observation_features`. + +```python +def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise ConnectionError(f"{self} is not connected.") + + # Read arm position + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + obs_dict[cam_key] = cam.async_read() + + return obs_dict +``` + +### `send_action()` + +Takes a dictionary that matches `action_features`, and sends it to your hardware. You can add safety limits (clipping, smoothing) and return what was actually sent. + +For simplicity, we won't be adding any modification of the actions in our example here. + +```python +def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items()} + + # Send goal position to the arm + self.bus.sync_write("Goal_Position", goal_pos) + + return action +``` + +## Adding a Teleoperator + +For implementing teleoperation devices, we also provide a [`Teleoperator`](https://github.com/huggingface/lerobot/blob/main/lerobot/common/teleoperators/teleoperator.py) base class. This class is very similar to the `Robot` base class and also doesn't assume anything on form factor. + +The main differences are in the I/O functions: a teleoperator allows you to produce action via `get_action` and can receive feedback actions via `send_feedback`. Feedback could be anything controllable on the teleoperation device that could help the person controlling it understand the consequences of the actions sent. Think motion/force feedback on a leader arm, vibrations on a gamepad controller for example. To implement a teleoperator, you can follow this same tutorial and adapt it for these two methods. + +## Wrapping Up + +Once your robot class is complete, you can leverage the LeRobot ecosystem: + +- Control your robot with available teleoperators or integrate directly your teleoperating device +- Record training data and visualize it +- Integrate it into RL or imitation learning pipelines + +Don't hesitate to reach out to the community for help on our [Discord](https://discord.gg/s3KuuzsPFb) 🤗 diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx new file mode 120000 index 000000000..b2399ae62 --- /dev/null +++ b/docs/source/koch.mdx @@ -0,0 +1 @@ +../../lerobot/common/robots/koch_follower/koch.mdx \ No newline at end of file diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx new file mode 120000 index 000000000..e2b4ff552 --- /dev/null +++ b/docs/source/lekiwi.mdx @@ -0,0 +1 @@ +../../lerobot/common/robots/lekiwi/lekiwi.mdx \ No newline at end of file diff --git a/docs/source/notebooks.mdx b/docs/source/notebooks.mdx new file mode 100644 index 000000000..729b31a99 --- /dev/null +++ b/docs/source/notebooks.mdx @@ -0,0 +1,29 @@ +# 🤗 LeRobot Notebooks + +This repository contains example notebooks for using LeRobot. These notebooks demonstrate how to train policies on real or simulation datasets using standardized policies. + +--- + +### Training ACT + +[ACT](https://huggingface.co/papers/2304.13705) (Action Chunking Transformer) is a transformer-based policy architecture for imitation learning that processes robot states and camera inputs to generate smooth, chunked action sequences. + +We provide a ready-to-run Google Colab notebook to help you train ACT policies using datasets from the Hugging Face Hub, with optional logging to Weights & Biases. + +| Notebook | Colab | +|:---------|:------| +| [Train ACT with LeRobot](https://github.com/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-act.ipynb) | + +Expected training time for 100k steps: ~1.5 hours on an NVIDIA A100 GPU with batch size of `64`. + +### Training SmolVLA + +[SmolVLA](https://huggingface.co/papers/2506.01844) is a small but efficient Vision-Language-Action model. It is compact in size with 450 M-parameter and is developed by Hugging Face. + +We provide a ready-to-run Google Colab notebook to help you train SmolVLA policies using datasets from the Hugging Face Hub, with optional logging to Weights & Biases. + +| Notebook | Colab | +| :-------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| [Train SmolVLA with LeRobot](https://github.com/huggingface/notebooks/blob/main/lerobot/training-smolvla.ipynb) | [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-smolvla.ipynb) | + +Expected training time for 20k steps: ~5 hours on an NVIDIA A100 GPU with batch size of `64`. diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx new file mode 100644 index 000000000..1d6596f65 --- /dev/null +++ b/docs/source/smolvla.mdx @@ -0,0 +1,97 @@ +# Finetune SmolVLA + +SmolVLA is Hugging Face’s lightweight foundation model for robotics. Designed for easy fine-tuning on LeRobot datasets, it helps accelerate your development! + +

+ SmolVLA architecture. +
+ Figure 1. SmolVLA takes as input (i) multiple cameras views, (ii) the robot’s current sensorimotor state, and (iii) a natural language instruction, encoded into contextual features used to condition the action expert when generating an action chunk. +

+ +## Set Up Your Environment + +1. Install LeRobot by following our [Installation Guide](./installation). +2. Install SmolVLA dependencies by running: + + ```bash + pip install -e ".[smolvla]" + ``` + +## Collect a dataset + +SmolVLA is a base model, so fine-tuning on your own data is required for optimal performance in your setup. +We recommend recording ~50 episodes of your task as a starting point. Follow our guide to get started: [Recording a Dataset](https://huggingface.co/docs/lerobot/getting_started_real_world_robot#record-a-dataset) + + + +In your dataset, make sure to have enough demonstrations per each variation (e.g. the cube position on the table if it is cube pick-place task) you are introducing. + +We recommend checking out the dataset linked below for reference that was used in the [SmolVLA paper](https://huggingface.co/papers/2506.01844): + +🔗 [SVLA SO100 PickPlace](https://huggingface.co/spaces/lerobot/visualize_dataset?path=%2Flerobot%2Fsvla_so100_pickplace%2Fepisode_0) + +In this dataset, we recorded 50 episodes across 5 distinct cube positions. For each position, we collected 10 episodes of pick-and-place interactions. This structure, repeating each variation several times, helped the model generalize better. We tried similar dataset with 25 episodes, and it was not enough leading to a bad performance. So, the data quality and quantity is definitely a key. +After you have your dataset available on the Hub, you are good to go to use our finetuning script to adapt SmolVLA to your application. + + +## Finetune SmolVLA on your data + +Use [`smolvla_base`](https://hf.co/lerobot/smolvla_base), our pretrained 450M model, and fine-tune it on your data. +Training the model for 20k steps will roughly take ~4 hrs on a single A100 GPU. You should tune the number of steps based on performance and your use-case. + +If you don't have a gpu device, you can train using our notebook on [![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/lerobot/training-smolvla.ipynb) + +Pass your dataset to the training script using `--dataset.repo_id`. If you want to test your installation, run the following command where we use one of the datasets we collected for the [SmolVLA Paper](https://huggingface.co/papers/2506.01844). + +```bash +cd lerobot && python lerobot/scripts/train.py \ + --policy.path=lerobot/smolvla_base \ + --dataset.repo_id=${HF_USER}/mydataset \ + --batch_size=64 \ + --steps=20000 \ + --output_dir=outputs/train/my_smolvla \ + --job_name=my_smolvla_training \ + --policy.device=cuda \ + --wandb.enable=true +``` + + +You can start with a small batch size and increase it incrementally, if the GPU allows it, as long as loading times remain short. + + +Fine-tuning is an art. For a complete overview of the options for finetuning, run + +```bash +python lerobot/scripts/train.py --help +``` + +

+ Comparison of SmolVLA across task variations. +
+ Figure 2: Comparison of SmolVLA across task variations. From left to right: (1) pick-place cube counting, (2) pick-place cube counting, (3) pick-place cube counting under perturbations, and (4) generalization on pick-and-place of the lego block with real-world SO101. +

+ + +## Evaluate the finetuned model and run it in real-time + +Similarly for when recording an episode, it is recommended that you are logged in to the HuggingFace Hub. You can follow the corresponding steps: [Record a dataset](./getting_started_real_world_robot#record-a-dataset). +Once you are logged in, you can run inference in your setup by doing: + +```bash +python -m lerobot.record \ + --robot.type=so101_follower \ + --robot.port=/dev/ttyACM0 \ # <- Use your port + --robot.id=my_blue_follower_arm \ # <- Use your robot id + --robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras + --dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording + --dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub + --dataset.episode_time_s=50 \ + --dataset.num_episodes=10 \ + # <- Teleop optional if you want to teleoperate in between episodes \ + # --teleop.type=so100_leader \ + # --teleop.port=/dev/ttyACM0 \ + # --teleop.id=my_red_leader_arm \ + --policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model +``` + +Depending on your evaluation setup, you can configure the duration and the number of episodes to record for your evaluation suite. diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx new file mode 120000 index 000000000..65849e950 --- /dev/null +++ b/docs/source/so100.mdx @@ -0,0 +1 @@ +../../lerobot/common/robots/so100_follower/so100.mdx \ No newline at end of file diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx new file mode 120000 index 000000000..dc4720c28 --- /dev/null +++ b/docs/source/so101.mdx @@ -0,0 +1 @@ +../../lerobot/common/robots/so101_follower/so101.mdx \ No newline at end of file diff --git a/examples/10_use_so100.md b/examples/10_use_so100.md deleted file mode 100644 index 9dbe974c1..000000000 --- a/examples/10_use_so100.md +++ /dev/null @@ -1,624 +0,0 @@ -# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot - -## Table of Contents - - - [A. Source the parts](#a-source-the-parts) - - [B. Install LeRobot](#b-install-lerobot) - - [C. Configure the Motors](#c-configure-the-motors) - - [D. Step-by-Step Assembly Instructions](#d-step-by-step-assembly-instructions) - - [E. Calibrate](#e-calibrate) - - [F. Teleoperate](#f-teleoperate) - - [G. Record a dataset](#g-record-a-dataset) - - [H. Visualize a dataset](#h-visualize-a-dataset) - - [I. Replay an episode](#i-replay-an-episode) - - [J. Train a policy](#j-train-a-policy) - - [K. Evaluate your policy](#k-evaluate-your-policy) - - [L. More Information](#l-more-information) - -## A. Source the parts - -Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, -and advice if it's your first time printing or if you don't own a 3D printer. - -Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly. - -## B. Install LeRobot - -> [!TIP] -> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line) - -On your computer: - -#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install): - -#### 2. Restart shell -Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell - -#### 3. Create and activate a fresh conda environment for lerobot - -
-Video install instructions - - - -
- -```bash -conda create -y -n lerobot python=3.10 -``` - -Then activate your conda environment (do this each time you open a shell to use lerobot!): -```bash -conda activate lerobot -``` - -#### 4. Clone LeRobot: -```bash -git clone https://github.com/huggingface/lerobot.git ~/lerobot -``` - -#### 5. Install ffmpeg in your environment: -When using `miniconda`, install `ffmpeg` in your environment: -```bash -conda install ffmpeg -c conda-forge -``` - -#### 6. Install LeRobot with dependencies for the feetech motors: -```bash -cd ~/lerobot && pip install -e ".[feetech]" -``` - -Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms :robot:. -Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands. - -## C. Configure the motors - -> [!NOTE] -> Throughout this tutorial you will find videos on how to do the steps, the full video tutorial can be found here: [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). - -### 1. Find the USB ports associated to each arm - -Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm. It's convenient to label them and write on each motor if it's for the follower `F` or for the leader `L` and it's ID from 1 to 6 (F1...F6 and L1...L6). - -#### a. Run the script to find port - -
-Video finding port - - -
- -To find the port for each bus servo adapter, run the utility script: -```bash -python lerobot/scripts/find_motors_bus_port.py -``` - -#### b. Example outputs - -Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your MotorsBus and press Enter when done. - -[...Disconnect leader arm and press Enter...] - -The port of this MotorsBus is /dev/tty.usbmodem575E0031751 -Reconnect the usb cable. -``` -Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your MotorsBus and press Enter when done. - -[...Disconnect follower arm and press Enter...] - -The port of this MotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the usb cable. -``` - -#### c. Troubleshooting -On Linux, you might need to give access to the USB ports by running: -```bash -sudo chmod 666 /dev/ttyACM0 -sudo chmod 666 /dev/ttyACM1 -``` - -#### d. Update config file - -IMPORTANTLY: Now that you have your ports, update the **port** default values of [`SO100RobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like: -```python -@RobotConfig.register_subclass("so100") -@dataclass -class So100RobotConfig(ManipulatorRobotConfig): - calibration_dir: str = ".cache/calibration/so100" - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) -``` - -### 2. Assembling the Base -Let's begin with assembling the follower arm base - -#### a. Set IDs for all 12 motors - -
-Video configuring motor - - -
- -Plug your first motor F1 and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate. Replace the text after --port to the corresponding follower control board port and run this command in cmd: -```bash -python lerobot/scripts/configure_motor.py \ - --port /dev/tty.usbmodem58760432961 \ - --brand feetech \ - --model sts3215 \ - --baudrate 1000000 \ - --ID 1 -``` - -> [!NOTE] -> These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees). - -Then unplug your motor and plug the second motor and set its ID to 2. -```bash -python lerobot/scripts/configure_motor.py \ - --port /dev/tty.usbmodem58760432961 \ - --brand feetech \ - --model sts3215 \ - --baudrate 1000000 \ - --ID 2 -``` - -Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm. - - -#### b. Remove the gears of the 6 leader motors - -
-Video removing gears - - - -
- - -Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. - -## D. Step-by-Step Assembly Instructions - -**Step 1: Clean Parts** -- Remove all support material from the 3D-printed parts. ---- - -### Additional Guidance - -
-Video assembling arms - - - -
- -**Note:** -This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour. - ---- - -### First Motor - -**Step 2: Insert Wires** -- Insert two wires into the first motor. - - - -**Step 3: Install in Base** -- Place the first motor into the base. - - - -**Step 4: Secure Motor** -- Fasten the motor with 4 screws. Two from the bottom and two from top. - -**Step 5: Attach Motor Holder** -- Slide over the first motor holder and fasten it using two screws (one on each side). - - - -**Step 6: Attach Motor Horns** -- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears. - - -
- Video adding motor horn - -
- -**Step 7: Attach Shoulder Part** -- Route one wire to the back of the robot and the other to the left or in photo towards you (see photo). -- Attach the shoulder part. - - - -**Step 8: Secure Shoulder** -- Tighten the shoulder part with 4 screws on top and 4 on the bottom -*(access bottom holes by turning the shoulder).* - ---- - -### Second Motor Assembly - -**Step 9: Install Motor 2** -- Slide the second motor in from the top and link the wire from motor 1 to motor 2. - - - -**Step 10: Attach Shoulder Holder** -- Add the shoulder motor holder. -- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo). -- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor. - -
- - - -
- -**Step 11: Secure Motor 2** -- Fasten the second motor with 4 screws. - -**Step 12: Attach Motor Horn** -- Attach both motor horns to motor 2, again use the horn screw. - -**Step 13: Attach Base** -- Install the base attachment using 2 screws. - - - -**Step 14: Attach Upper Arm** -- Attach the upper arm with 4 screws on each side. - - - ---- - -### Third Motor Assembly - -**Step 15: Install Motor 3** -- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws. - -**Step 16: Attach Motor Horn** -- Attach both motor horns to motor 3 and secure one again with a horn screw. - - - -**Step 17: Attach Forearm** -- Connect the forearm to motor 3 using 4 screws on each side. - - - ---- - -### Fourth Motor Assembly - -**Step 18: Install Motor 4** -- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw. - -
- - -
- -**Step 19: Attach Motor Holder 4** -- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo). - - - -**Step 20: Secure Motor 4 & Attach Horn** -- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw. - - - ---- - -### Wrist Assembly - -**Step 21: Install Motor 5** -- Insert motor 5 into the wrist holder and secure it with 2 front screws. - - - -**Step 22: Attach Wrist** -- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper. -- Secure the wrist to motor 4 using 4 screws on both sides. - - - -**Step 23: Attach Wrist Horn** -- Install only one motor horn on the wrist motor and secure it with a horn screw. - - - ---- - -### Follower Configuration - -**Step 24: Attach Gripper** -- Attach the gripper to motor 5. - - - -**Step 25: Install Gripper Motor** -- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side. - - - -**Step 26: Attach Gripper Horn & Claw** -- Attach the motor horns and again use a horn screw. -- Install the gripper claw and secure it with 4 screws on both sides. - - - -**Step 27: Mount Controller** -- Attach the motor controller on the back. - -
- - -
- -*Assembly complete – proceed to Leader arm assembly.* - ---- - -### Leader Configuration - -For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors. - -**Step 24: Attach Leader Holder** -- Mount the leader holder onto the wrist and secure it with a screw. - - - -**Step 25: Attach Handle** -- Attach the handle to motor 5 using 4 screws. - - - -**Step 26: Install Gripper Motor** -- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire. - - - -**Step 27: Attach Trigger** -- Attach the follower trigger with 4 screws. - - - -**Step 28: Mount Controller** -- Attach the motor controller on the back. - -
- - -
- -*Assembly complete – proceed to calibration.* - - -## E. Calibrate - -Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another. - -#### a. Manual calibration of follower arm - -> [!IMPORTANT] -> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now. - -You will need to move the follower arm to these positions sequentially: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| SO-100 follower arm zero position | SO-100 follower arm rotated position | SO-100 follower arm rest position | - -Make sure both arms are connected and run this script to launch manual calibration: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --robot.cameras='{}' \ - --control.type=calibrate \ - --control.arms='["main_follower"]' -``` - -#### b. Manual calibration of leader arm -Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ | -| SO-100 leader arm zero position | SO-100 leader arm rotated position | SO-100 leader arm rest position | - -Run this script to launch manual calibration: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --robot.cameras='{}' \ - --control.type=calibrate \ - --control.arms='["main_leader"]' -``` - -## F. Teleoperate - -**Simple teleop** -Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras): -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --robot.cameras='{}' \ - --control.type=teleoperate -``` - - -#### a. Teleop with displaying cameras -Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset. - -> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. - -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=teleoperate -``` - -## G. Record a dataset - -Once you're familiar with teleoperation, you can record your first dataset with SO-100. - -If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): -```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential -``` - -Store your Hugging Face repository name in a variable to run these commands: -```bash -HF_USER=$(huggingface-cli whoami | head -n 1) -echo $HF_USER -``` - -Record 2 episodes and upload your dataset to the hub: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=${HF_USER}/so100_test \ - --control.tags='["so100","tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=2 \ - --control.push_to_hub=true -``` - -Note: You can resume recording by adding `--control.resume=true`. - -## H. Visualize a dataset - -If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: -```bash -echo ${HF_USER}/so100_test -``` - -If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool): -```bash -python lerobot/scripts/visualize_dataset_html.py \ - --repo-id ${HF_USER}/so100_test \ - --local-files-only 1 -``` - -## I. Replay an episode - -Now try to replay the first episode on your robot: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=replay \ - --control.fps=30 \ - --control.repo_id=${HF_USER}/so100_test \ - --control.episode=0 -``` - -## J. Train a policy - -To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: -```bash -python lerobot/scripts/train.py \ - --dataset.repo_id=${HF_USER}/so100_test \ - --policy.type=act \ - --output_dir=outputs/train/act_so100_test \ - --job_name=act_so100_test \ - --policy.device=cuda \ - --wandb.enable=true -``` - -Let's explain it: -1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so100_test`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. - -Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`. - -To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so100_test` policy: -```bash -python lerobot/scripts/train.py \ - --config_path=outputs/train/act_so100_test/checkpoints/last/pretrained_model/train_config.json \ - --resume=true -``` - -## K. Evaluate your policy - -You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=${HF_USER}/eval_act_so100_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=10 \ - --control.push_to_hub=true \ - --control.policy.path=outputs/train/act_so100_test/checkpoints/last/pretrained_model -``` - -As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_so100_test`). -2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_so100_test`). - -## L. More Information - -Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot. - -> [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#so100-arm`](https://discord.com/channels/1216765309076115607/1237741463832363039). diff --git a/examples/11_use_lekiwi.md b/examples/11_use_lekiwi.md deleted file mode 100644 index 1be7cbc4a..000000000 --- a/examples/11_use_lekiwi.md +++ /dev/null @@ -1,597 +0,0 @@ -# Using the [LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) Robot with LeRobot - -## Table of Contents - - - [A. Source the parts](#a-source-the-parts) - - [B. Install software Pi](#b-install-software-on-pi) - - [C. Setup LeRobot laptop/pc](#c-install-lerobot-on-laptop) - - [D. Assemble the arms](#d-assembly) - - [E. Calibrate](#e-calibration) - - [F. Teleoperate](#f-teleoperate) - - [G. Record a dataset](#g-record-a-dataset) - - [H. Visualize a dataset](#h-visualize-a-dataset) - - [I. Replay an episode](#i-replay-an-episode) - - [J. Train a policy](#j-train-a-policy) - - [K. Evaluate your policy](#k-evaluate-your-policy) - -> [!TIP] -> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb) in the channel [`#mobile-so-100-arm`](https://discord.com/channels/1216765309076115607/1318390825528332371). - -## A. Source the parts - -Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts, and advice if it's your first time printing or if you don't own a 3D printer. - -Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly. - -### Wired version -If you have the **wired** LeKiwi version you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating. - -## B. Install software on Pi -Now we have to setup the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board. - -### Install OS -For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu. - -### Setup SSH -After setting up your Pi, you should enable and setup [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can login into the Pi from your laptop without requiring a screen, keyboard and mouse in the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension. - -### Install LeRobot - -On your Raspberry Pi: - -#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install): - -#### 2. Restart shell -Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell - -#### 3. Create and activate a fresh conda environment for lerobot - -
-Video install instructions - - - -
- -```bash -conda create -y -n lerobot python=3.10 -``` - -Then activate your conda environment (do this each time you open a shell to use lerobot!): -```bash -conda activate lerobot -``` - -#### 4. Clone LeRobot: -```bash -git clone https://github.com/huggingface/lerobot.git ~/lerobot -``` - -#### 5. Install ffmpeg in your environment: -When using `miniconda`, install `ffmpeg` in your environment: -```bash -conda install ffmpeg -c conda-forge -``` - -#### 6. Install LeRobot with dependencies for the feetech motors: -```bash -cd ~/lerobot && pip install -e ".[feetech]" -``` - -## C. Install LeRobot on laptop -If you already have install LeRobot on your laptop you can skip this step, otherwise please follow along as we do the same steps we did on the Pi. - -> [!TIP] -> We use the Command Prompt (cmd) quite a lot. If you are not comfortable using the cmd or want to brush up using the command line you can have a look here: [Command line crash course](https://developer.mozilla.org/en-US/docs/Learn_web_development/Getting_started/Environment_setup/Command_line) - -On your computer: - -#### 1. [Install Miniconda](https://docs.anaconda.com/miniconda/install/#quick-command-line-install): - -#### 2. Restart shell -Copy paste in your shell: `source ~/.bashrc` or for Mac: `source ~/.bash_profile` or `source ~/.zshrc` if you're using zshell - -#### 3. Create and activate a fresh conda environment for lerobot - -
-Video install instructions - - - -
- -```bash -conda create -y -n lerobot python=3.10 -``` - -Then activate your conda environment (do this each time you open a shell to use lerobot!): -```bash -conda activate lerobot -``` - -#### 4. Clone LeRobot: -```bash -git clone https://github.com/huggingface/lerobot.git ~/lerobot -``` - -#### 5. Install ffmpeg in your environment: -When using `miniconda`, install `ffmpeg` in your environment: -```bash -conda install ffmpeg -c conda-forge -``` - -#### 6. Install LeRobot with dependencies for the feetech motors: -```bash -cd ~/lerobot && pip install -e ".[feetech]" -``` - -Great :hugs:! You are now done installing LeRobot and we can begin assembling the SO100 arms and Mobile base :robot:. -Every time you now want to use LeRobot you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands. - -# D. Assembly - -First we will assemble the two SO100 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base. - -## SO100 Arms -### Configure motors -The instructions for configuring the motors can be found [Here](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#c-configure-the-motors) in step C of the SO100 tutorial. Besides the ID's for the arm motors we also need to set the motor ID's for the mobile base. These needs to be in a specific order to work. Below an image of the motor ID's and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ID's for the wheels are 7, 8 and 9. - -Motor ID's for mobile robot - -### Assemble arms -[Assemble arms instruction](https://github.com/huggingface/lerobot/blob/main/examples/10_use_so100.md#d-assemble-the-arms) - -## Mobile base (LeKiwi) -[Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi) - -### Update config -Both config files on the LeKiwi LeRobot and on the laptop should be the same. First we should find the Ip address of the Raspberry Pi of the mobile manipulator. This is the same Ip address used in SSH. We also need the usb port of the control board of the leader arm on the laptop and the port of the control board on LeKiwi. We can find these ports with the following script. - -#### a. Run the script to find port - -
-Video finding port - - -
- -To find the port for each bus servo adapter, run the utility script: -```bash -python lerobot/scripts/find_motors_bus_port.py -``` - -#### b. Example outputs - -Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your DynamixelMotorsBus and press Enter when done. - -[...Disconnect leader arm and press Enter...] - -The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751 -Reconnect the usb cable. -``` -Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your DynamixelMotorsBus and press Enter when done. - -[...Disconnect follower arm and press Enter...] - -The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the usb cable. -``` - -#### c. Troubleshooting -On Linux, you might need to give access to the USB ports by running: -```bash -sudo chmod 666 /dev/ttyACM0 -sudo chmod 666 /dev/ttyACM1 -``` - -#### d. Update config file - -IMPORTANTLY: Now that you have your ports of leader and follower arm and ip address of the mobile-so100, update the **ip** in Network configuration, **port** in leader_arms and **port** in lekiwi. In the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py) file. Where you will find something like: -```python -@RobotConfig.register_subclass("lekiwi") -@dataclass -class LeKiwiRobotConfig(RobotConfig): - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - # Network Configuration - ip: str = "172.17.133.91" - port: int = 5555 - video_port: int = 5556 - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "mobile": OpenCVCameraConfig(camera_index="/dev/video0", fps=30, width=640, height=480), - "mobile2": OpenCVCameraConfig(camera_index="/dev/video2", fps=30, width=640, height=480), - } - ) - - calibration_dir: str = ".cache/calibration/lekiwi" - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0077581", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/ttyACM0", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - "left_wheel": (7, "sts3215"), - "back_wheel": (8, "sts3215"), - "right_wheel": (9, "sts3215"), - }, - ), - } - ) - - teleop_keys: dict[str, str] = field( - default_factory=lambda: { - # Movement - "forward": "w", - "backward": "s", - "left": "a", - "right": "d", - "rotate_left": "z", - "rotate_right": "x", - # Speed control - "speed_up": "r", - "speed_down": "f", - # quit teleop - "quit": "q", - } - ) - - mock: bool = False -``` - -## Wired version - -For the wired LeKiwi version your configured IP address should refer to your own laptop (127.0.0.1), because leader arm and LeKiwi are in this case connected to own laptop. Below and example configuration for this wired setup: -```python -@RobotConfig.register_subclass("lekiwi") -@dataclass -class LeKiwiRobotConfig(RobotConfig): - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - # Network Configuration - ip: str = "127.0.0.1" - port: int = 5555 - video_port: int = 5556 - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "front": OpenCVCameraConfig( - camera_index=0, fps=30, width=640, height=480, rotation=90 - ), - "wrist": OpenCVCameraConfig( - camera_index=1, fps=30, width=640, height=480, rotation=180 - ), - } - ) - - calibration_dir: str = ".cache/calibration/lekiwi" - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0077581", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760431061", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - "left_wheel": (7, "sts3215"), - "back_wheel": (8, "sts3215"), - "right_wheel": (9, "sts3215"), - }, - ), - } - ) - - teleop_keys: dict[str, str] = field( - default_factory=lambda: { - # Movement - "forward": "w", - "backward": "s", - "left": "a", - "right": "d", - "rotate_left": "z", - "rotate_right": "x", - # Speed control - "speed_up": "r", - "speed_down": "f", - # quit teleop - "quit": "q", - } - ) - - mock: bool = False -``` - -# E. Calibration -Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated. - - -### Calibrate follower arm (on mobile base) -> [!IMPORTANT] -> Contrarily to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724) which illustrates the auto calibration, we will actually do manual calibration of follower for now. - -You will need to move the follower arm to these positions sequentially: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| SO-100 follower arm zero position | SO-100 follower arm rotated position | SO-100 follower arm rest position | - -Make sure the arm is connected to the Raspberry Pi and run this script (on the Raspberry Pi) to launch manual calibration: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --robot.cameras='{}' \ - --control.type=calibrate \ - --control.arms='["main_follower"]' -``` - -### Wired version -If you have the **wired** LeKiwi version please run all commands including this calibration command on your laptop. - -### Calibrate leader arm -Then to calibrate the leader arm (which is attached to the laptop/pc). You will need to move the leader arm to these positions sequentially: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ | -| SO-100 leader arm zero position | SO-100 leader arm rotated position | SO-100 leader arm rest position | - -Run this script (on your laptop/pc) to launch manual calibration: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --robot.cameras='{}' \ - --control.type=calibrate \ - --control.arms='["main_leader"]' -``` - -# F. Teleoperate - -> [!TIP] -> If you're using a Mac, you might need to give Terminal permission to access your keyboard. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal. - -To teleoperate SSH into your Raspberry Pi, and run `conda activate lerobot` and this script: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --control.type=remote_robot -``` - -Then on your laptop, also run `conda activate lerobot` and this script: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --control.type=teleoperate \ - --control.fps=30 -``` - -> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. For the `--control.type=remote_robot` you will also need to set `--control.viewer_ip` and `--control.viewer_port` - -You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below: -| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) | -| ---------- | ------------------ | ---------------------- | -| Fast | 0.4 | 90 | -| Medium | 0.25 | 60 | -| Slow | 0.1 | 30 | - - -| Key | Action | -| --- | -------------- | -| W | Move forward | -| A | Move left | -| S | Move backward | -| D | Move right | -| Z | Turn left | -| X | Turn right | -| R | Increase speed | -| F | Decrease speed | - -> [!TIP] -> If you use a different keyboard you can change the keys for each command in the [`LeKiwiRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). - -### Wired version -If you have the **wired** LeKiwi version please run all commands including both these teleoperation commands on your laptop. - -## Troubleshoot communication - -If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue. - -### 1. Verify IP Address Configuration -Make sure that the correct ip for the Pi is set in the configuration file. To check the Raspberry Pi's IP address, run (on the Pi command line): -```bash -hostname -I -``` - -### 2. Check if Pi is reachable from laptop/pc -Try pinging the Raspberry Pi from your laptop: -```bach -ping -``` - -If the ping fails: -- Ensure the Pi is powered on and connected to the same network. -- Check if SSH is enabled on the Pi. - -### 3. Try SSH connection -If you can't SSH into the Pi, it might not be properly connected. Use: -```bash -ssh @ -``` -If you get a connection error: -- Ensure SSH is enabled on the Pi by running: - ```bash - sudo raspi-config - ``` - Then navigate to: **Interfacing Options -> SSH** and enable it. - -### 4. Same config file -Make sure the configuration file on both your laptop/pc and the Raspberry Pi is the same. - -# G. Record a dataset -Once you're familiar with teleoperation, you can record your first dataset with LeKiwi. - -To start the program on LeKiwi, SSH into your Raspberry Pi, and run `conda activate lerobot` and this script: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --control.type=remote_robot -``` - -If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): -```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential -``` - -Store your Hugging Face repository name in a variable to run these commands: -```bash -HF_USER=$(huggingface-cli whoami | head -n 1) -echo $HF_USER -``` -On your laptop then run this command to record 2 episodes and upload your dataset to the hub: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=${HF_USER}/lekiwi_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=2 \ - --control.push_to_hub=true -``` - -Note: You can resume recording by adding `--control.resume=true`. - -### Wired version -If you have the **wired** LeKiwi version please run all commands including both these record dataset commands on your laptop. - -# H. Visualize a dataset - -If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: -```bash -echo ${HF_USER}/lekiwi_test -``` - -If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with (a window can be opened in the browser `http://127.0.0.1:9090` with the visualization tool): -```bash -python lerobot/scripts/visualize_dataset_html.py \ - --repo-id ${HF_USER}/lekiwi_test \ - --local-files-only 1 -``` - -# I. Replay an episode -Now try to replay the first episode on your robot: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --control.type=replay \ - --control.fps=30 \ - --control.repo_id=${HF_USER}/lekiwi_test \ - --control.episode=0 -``` - -## J. Train a policy - -To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: -```bash -python lerobot/scripts/train.py \ - --dataset.repo_id=${HF_USER}/lekiwi_test \ - --policy.type=act \ - --output_dir=outputs/train/act_lekiwi_test \ - --job_name=act_lekiwi_test \ - --policy.device=cuda \ - --wandb.enable=true -``` - -Let's explain it: -1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/lekiwi_test`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. - -Training should take several hours. You will find checkpoints in `outputs/train/act_lekiwi_test/checkpoints`. - -## K. Evaluate your policy - -You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Drive to the red block and pick it up" \ - --control.repo_id=${HF_USER}/eval_act_lekiwi_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=10 \ - --control.push_to_hub=true \ - --control.policy.path=outputs/train/act_lekiwi_test/checkpoints/last/pretrained_model -``` - -As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_lekiwi_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_lekiwi_test`). -2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_lekiwi_test`). diff --git a/examples/11_use_moss.md b/examples/11_use_moss.md deleted file mode 100644 index 1b6f23b9a..000000000 --- a/examples/11_use_moss.md +++ /dev/null @@ -1,337 +0,0 @@ -This tutorial explains how to use [Moss v1](https://github.com/jess-moss/moss-robot-arms) with LeRobot. - -## Source the parts - -Follow this [README](https://github.com/jess-moss/moss-robot-arms). It contains the bill of materials with link to source the parts, as well as the instructions to 3D print the parts and advice if it's your first time printing or if you don't own a 3D printer already. - -**Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly. - -## Install LeRobot - -On your computer: - -1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install): -```bash -mkdir -p ~/miniconda3 -wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh -bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 -rm ~/miniconda3/miniconda.sh -~/miniconda3/bin/conda init bash -``` - -2. Restart shell or `source ~/.bashrc` - -3. Create and activate a fresh conda environment for lerobot -```bash -conda create -y -n lerobot python=3.10 && conda activate lerobot -``` - -4. Clone LeRobot: -```bash -git clone https://github.com/huggingface/lerobot.git ~/lerobot -``` - -5. Install ffmpeg in your environment: -When using `miniconda`, install `ffmpeg` in your environment: -```bash -conda install ffmpeg -c conda-forge -``` - -6. Install LeRobot with dependencies for the feetech motors: -```bash -cd ~/lerobot && pip install -e ".[feetech]" -``` - -## Configure the motors - -Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the use of our scripts below. - -**Find USB ports associated to your arms** -To find the correct ports for each arm, run the utility script twice: -```bash -python lerobot/scripts/find_motors_bus_port.py -``` - -Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your DynamixelMotorsBus and press Enter when done. - -[...Disconnect leader arm and press Enter...] - -The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751 -Reconnect the usb cable. -``` - -Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your DynamixelMotorsBus and press Enter when done. - -[...Disconnect follower arm and press Enter...] - -The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the usb cable. -``` - -Troubleshooting: On Linux, you might need to give access to the USB ports by running: -```bash -sudo chmod 666 /dev/ttyACM0 -sudo chmod 666 /dev/ttyACM1 -``` - -#### Update config file - -IMPORTANTLY: Now that you have your ports, update the **port** default values of [`MossRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like: -```python -@RobotConfig.register_subclass("moss") -@dataclass -class MossRobotConfig(ManipulatorRobotConfig): - calibration_dir: str = ".cache/calibration/moss" - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760431091", <-- UPDATE HERE - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) -``` - -**Configure your motors** -Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate: -```bash -python lerobot/scripts/configure_motor.py \ - --port /dev/tty.usbmodem58760432961 \ - --brand feetech \ - --model sts3215 \ - --baudrate 1000000 \ - --ID 1 -``` - -Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees). - -Then unplug your motor and plug the second motor and set its ID to 2. -```bash -python lerobot/scripts/configure_motor.py \ - --port /dev/tty.usbmodem58760432961 \ - --brand feetech \ - --model sts3215 \ - --baudrate 1000000 \ - --ID 2 -``` - -Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm. - -**Remove the gears of the 6 leader motors** -Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. - -**Add motor horn to the motors** -Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). For Moss v1, you need to align the holes on the motor horn to the motor spline to be approximately 3, 6, 9 and 12 o'clock. -Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated. - -## Assemble the arms - -Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm. - -## Calibrate - -Next, you'll need to calibrate your Moss v1 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Moss v1 robot to work on another. - -**Manual calibration of follower arm** -/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the auto calibration, we will actually do manual calibration of follower for now. - -You will need to move the follower arm to these positions sequentially: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Moss v1 follower arm zero position | Moss v1 follower arm rotated position | Moss v1 follower arm rest position | - -Make sure both arms are connected and run this script to launch manual calibration: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=moss \ - --robot.cameras='{}' \ - --control.type=calibrate \ - --control.arms='["main_follower"]' -``` - -**Manual calibration of leader arm** -Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=DA91NJOtMic) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Moss v1 leader arm zero position | Moss v1 leader arm rotated position | Moss v1 leader arm rest position | - -Run this script to launch manual calibration: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=moss \ - --robot.cameras='{}' \ - --control.type=calibrate \ - --control.arms='["main_leader"]' -``` - -## Teleoperate - -**Simple teleop** -Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras): -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=moss \ - --robot.cameras='{}' \ - --control.type=teleoperate -``` - - -**Teleop with displaying cameras** -Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset. - -> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. - -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=moss \ - --control.type=teleoperate -``` - -## Record a dataset - -Once you're familiar with teleoperation, you can record your first dataset with Moss v1. - -If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): -```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential -``` - -Store your Hugging Face repository name in a variable to run these commands: -```bash -HF_USER=$(huggingface-cli whoami | head -n 1) -echo $HF_USER -``` - -Record 2 episodes and upload your dataset to the hub: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=moss \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=${HF_USER}/moss_test \ - --control.tags='["moss","tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=2 \ - --control.push_to_hub=true -``` - -Note: You can resume recording by adding `--control.resume=true`. - -## Visualize a dataset - -If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: -```bash -echo ${HF_USER}/moss_test -``` - -If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with: -```bash -python lerobot/scripts/visualize_dataset_html.py \ - --repo-id ${HF_USER}/moss_test \ - --local-files-only 1 -``` - -## Replay an episode - -Now try to replay the first episode on your robot: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=moss \ - --control.type=replay \ - --control.fps=30 \ - --control.repo_id=${HF_USER}/moss_test \ - --control.episode=0 -``` - -## Train a policy - -To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: -```bash -python lerobot/scripts/train.py \ - --dataset.repo_id=${HF_USER}/moss_test \ - --policy.type=act \ - --output_dir=outputs/train/act_moss_test \ - --job_name=act_moss_test \ - --policy.device=cuda \ - --wandb.enable=true -``` - -Let's explain it: -1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/moss_test`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. - -Training should take several hours. You will find checkpoints in `outputs/train/act_moss_test/checkpoints`. - -## Evaluate your policy - -You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=moss \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=${HF_USER}/eval_act_moss_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=10 \ - --control.push_to_hub=true \ - --control.policy.path=outputs/train/act_moss_test/checkpoints/last/pretrained_model -``` - -As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_moss_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_moss_test`). -2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_moss_test`). - -## More - -Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot. - -If you have any question or need help, please reach out on Discord in the channel [`#moss-arm`](https://discord.com/channels/1216765309076115607/1275374638985252925). diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index 686069589..4e6154c2e 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -13,7 +13,7 @@ # limitations under the License. """ -This scripts demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local +This script demonstrates how to evaluate a pretrained policy from the HuggingFace Hub or from your local training outputs directory. In the latter case, you might want to run examples/3_train_policy.py first. It requires the installation of the 'gym_pusht' simulation environment. Install it by running: @@ -119,7 +119,7 @@ while not done: rewards.append(reward) frames.append(env.render()) - # The rollout is considered done when the success state is reach (i.e. terminated is True), + # The rollout is considered done when the success state is reached (i.e. terminated is True), # or the maximum number of iterations is reached (i.e. truncated is True) done = terminated | truncated | done step += 1 diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index b23d22713..0c11afe98 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly ## The training script -LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following: +LeRobot offers a training script at [`lerobot/scripts/train.py`](../lerobot/scripts/train.py). At a high level it does the following: - Initialize/load a configuration for the following steps using. - Instantiates a dataset. @@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig def train(cfg: TrainPipelineConfig): ``` -You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option) +You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option) When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.) @@ -50,7 +50,7 @@ By default, every field takes its default value specified in the dataclass. If a ## Specifying values from the CLI -Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: +Let's say that we want to train [Diffusion Policy](../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: ```bash python lerobot/scripts/train.py \ --dataset.repo_id=lerobot/pusht \ @@ -60,10 +60,10 @@ python lerobot/scripts/train.py \ Let's break this down: - To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`. -- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies) -- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py) +- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../lerobot/common/policies) +- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../lerobot/common/envs/configs.py) -Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: +Let's see another example. Let's say you've been training [ACT](../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: ```bash python lerobot/scripts/train.py \ --policy.type=act \ @@ -74,7 +74,7 @@ python lerobot/scripts/train.py \ > Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`. We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task. -Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: +Looking at the [`AlohaEnv`](../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: ```bash python lerobot/scripts/train.py \ --policy.type=act \ diff --git a/examples/7_get_started_with_real_robot.md b/examples/7_get_started_with_real_robot.md deleted file mode 100644 index a31524bfb..000000000 --- a/examples/7_get_started_with_real_robot.md +++ /dev/null @@ -1,998 +0,0 @@ -# Getting Started with Real-World Robots - -This tutorial will guide you through the process of setting up and training a neural network to autonomously control a real robot. - -**What You'll Learn:** -1. How to order and assemble your robot. -2. How to connect, configure, and calibrate your robot. -3. How to record and visualize your dataset. -4. How to train a policy using your data and prepare it for evaluation. -5. How to evaluate your policy and visualize the results. - -By following these steps, you'll be able to replicate tasks like picking up a Lego block and placing it in a bin with a high success rate, as demonstrated in [this video](https://x.com/RemiCadene/status/1814680760592572934). - -This tutorial is specifically made for the affordable [Koch v1.1](https://github.com/jess-moss/koch-v1-1) robot, but it contains additional information to be easily adapted to various types of robots like [Aloha bimanual robot](https://aloha-2.github.io) by changing some configurations. The Koch v1.1 consists of a leader arm and a follower arm, each with 6 motors. It can work with one or several cameras to record the scene, which serve as visual sensors for the robot. - -During the data collection phase, you will control the follower arm by moving the leader arm. This process is known as "teleoperation." This technique is used to collect robot trajectories. Afterward, you'll train a neural network to imitate these trajectories and deploy the network to enable your robot to operate autonomously. - -If you encounter any issues at any step of the tutorial, feel free to seek help on [Discord](https://discord.com/invite/s3KuuzsPFb) or don't hesitate to iterate with us on the tutorial by creating issues or pull requests. Thanks! - -## 1. Order and Assemble your Koch v1.1 - -Follow the sourcing and assembling instructions provided on the [Koch v1.1 Github page](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below. - -
- Koch v1.1 leader and follower arms -
- -For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk). - -## 2. Configure motors, calibrate arms, teleoperate your Koch v1.1 - -First, install the additional dependencies required for robots built with dynamixel motors like Koch v1.1 by running one of the following commands (make sure gcc is installed). - -Using `pip`: -```bash -pip install -e ".[dynamixel]" -``` - -Using `poetry`: -```bash -poetry sync --extras "dynamixel" -``` - -Using `uv`: -```bash -uv sync --extra "dynamixel" -``` - -You are now ready to plug the 5V power supply to the motor bus of the leader arm (the smaller one) since all its motors only require 5V. - -Then plug the 12V power supply to the motor bus of the follower arm. It has two motors that need 12V, and the rest will be powered with 5V through the voltage convertor. - -Finally, connect both arms to your computer via USB. Note that the USB doesn't provide any power, and both arms need to be plugged in with their associated power supply to be detected by your computer. - -Now you are ready to configure your motors for the first time, as detailed in the sections below. In the upcoming sections, you'll learn about our classes and functions by running some python code in an interactive session, or by copy-pasting it in a python file. - -If you have already configured your motors the first time, you can streamline the process by directly running the teleoperate script (which is detailed further in the tutorial): - -> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. - -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=koch \ - --control.type=teleoperate -``` - -It will automatically: -1. Identify any missing calibrations and initiate the calibration procedure. -2. Connect the robot and start teleoperation. - -### a. Control your motors with DynamixelMotorsBus - -You can use the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) to communicate with the motors connected as a chain to the corresponding USB bus. This class leverages the Python [Dynamixel SDK](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20) to facilitate reading from and writing to the motors. - -**First Configuration of your motors** - -You will need to unplug each motor in turn and run a command the identify the motor. The motor will save its own identification, so you only need to do this once. Start by unplugging all of the motors. - -Do the Leader arm first, as all of its motors are of the same type. Plug in your first motor on your leader arm and run this script to set its ID to 1. -```bash -python lerobot/scripts/configure_motor.py \ - --port /dev/tty.usbmodem58760432961 \ - --brand dynamixel \ - --model xl330-m288 \ - --baudrate 1000000 \ - --ID 1 -``` - -Then unplug your first motor and plug the second motor and set its ID to 2. -```bash -python lerobot/scripts/configure_motor.py \ - --port /dev/tty.usbmodem58760432961 \ - --brand dynamixel \ - --model xl330-m288 \ - --baudrate 1000000 \ - --ID 2 -``` - -Redo the process for all your motors until ID 6. - -The process for the follower arm is almost the same, but the follower arm has two types of motors. For the first two motors, make sure you set the model to `xl430-w250`. _Important: configuring follower motors requires plugging and unplugging power. Make sure you use the 5V power for the XL330s and the 12V power for the XL430s!_ - -After all of your motors are configured properly, you're ready to plug them all together in a daisy-chain as shown in the original video. - -**Instantiate the DynamixelMotorsBus** - -To begin, create two instances of the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py), one for each arm, using their corresponding USB ports (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`). - -To find the correct ports for each arm, run the utility script twice: -```bash -python lerobot/scripts/find_motors_bus_port.py -``` - -Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your DynamixelMotorsBus and press Enter when done. - -[...Disconnect leader arm and press Enter...] - -The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751 -Reconnect the usb cable. -``` - -Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux): -``` -Finding all available ports for the MotorBus. -['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] -Remove the usb cable from your DynamixelMotorsBus and press Enter when done. - -[...Disconnect follower arm and press Enter...] - -The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081 -Reconnect the usb cable. -``` - -Troubleshooting: On Linux, you might need to give access to the USB ports by running this command with your ports: -```bash -sudo chmod 666 /dev/tty.usbmodem575E0032081 -sudo chmod 666 /dev/tty.usbmodem575E0031751 -``` - -*Listing and Configuring Motors* - -Next, you'll need to list the motors for each arm, including their name, index, and model. Initially, each motor is assigned the factory default index `1`. Since each motor requires a unique index to function correctly when connected in a chain on a common bus, you'll need to assign different indices. It's recommended to use an ascending index order, starting from `1` (e.g., `1, 2, 3, 4, 5, 6`). These indices will be saved in the persistent memory of each motor during the first connection. - -To assign indices to the motors, run this code in an interactive Python session. Replace the `port` values with the ones you identified earlier: -```python -from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig -from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus - -leader_config = DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem575E0031751", - motors={ - # name: (index, model) - "shoulder_pan": (1, "xl330-m077"), - "shoulder_lift": (2, "xl330-m077"), - "elbow_flex": (3, "xl330-m077"), - "wrist_flex": (4, "xl330-m077"), - "wrist_roll": (5, "xl330-m077"), - "gripper": (6, "xl330-m077"), - }, -) - -follower_config = DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem575E0032081", - motors={ - # name: (index, model) - "shoulder_pan": (1, "xl430-w250"), - "shoulder_lift": (2, "xl430-w250"), - "elbow_flex": (3, "xl330-m288"), - "wrist_flex": (4, "xl330-m288"), - "wrist_roll": (5, "xl330-m288"), - "gripper": (6, "xl330-m288"), - }, -) - -leader_arm = DynamixelMotorsBus(leader_config) -follower_arm = DynamixelMotorsBus(follower_config) -``` - -IMPORTANTLY: Now that you have your ports, update [`KochRobotConfig`](../lerobot/common/robot_devices/robots/configs.py). You will find something like: -```python -@RobotConfig.register_subclass("koch") -@dataclass -class KochRobotConfig(ManipulatorRobotConfig): - calibration_dir: str = ".cache/calibration/koch" - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem585A0085511", <-- UPDATE HERE - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl330-m077"], - "shoulder_lift": [2, "xl330-m077"], - "elbow_flex": [3, "xl330-m077"], - "wrist_flex": [4, "xl330-m077"], - "wrist_roll": [5, "xl330-m077"], - "gripper": [6, "xl330-m077"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", <-- UPDATE HERE - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl430-w250"], - "shoulder_lift": [2, "xl430-w250"], - "elbow_flex": [3, "xl330-m288"], - "wrist_flex": [4, "xl330-m288"], - "wrist_roll": [5, "xl330-m288"], - "gripper": [6, "xl330-m288"], - }, - ), - } - ) -``` - -**Connect and Configure your Motors** - -Before you can start using your motors, you'll need to configure them to ensure proper communication. When you first connect the motors, the [`DynamixelMotorsBus`](../lerobot/common/robot_devices/motors/dynamixel.py) automatically detects any mismatch between the current motor indices (factory set to `1`) and the specified indices (e.g., `1, 2, 3, 4, 5, 6`). This triggers a configuration procedure that requires you to unplug the power cord and motors, then reconnect each motor sequentially, starting from the one closest to the bus. - -For a visual guide, refer to the [video tutorial of the configuration procedure](https://youtu.be/U78QQ9wCdpY). - -To connect and configure the leader arm, run the following code in the same Python interactive session as earlier in the tutorial: -```python -leader_arm.connect() -``` - -When you connect the leader arm for the first time, you might see an output similar to this: -``` -Read failed due to communication error on port /dev/tty.usbmodem575E0032081 for group_key ID_shoulder_pan_shoulder_lift_elbow_flex_wrist_flex_wrist_roll_gripper: [TxRxResult] There is no status packet! - -/!\ A configuration issue has been detected with your motors: -If this is the first time you are using these motors, press enter to configure your motors... but before verify that all the cables are connected the proper way. If you find an issue, before making a modification, kill the python process, unplug the power cord to not damage the motors, rewire correctly, then plug the power again and relaunch the script. - -Motor indices detected: {9600: [1]} - -1. Unplug the power cord -2. Plug/unplug minimal number of cables to only have the first 1 motor(s) (['shoulder_pan']) connected. -3. Re-plug the power cord -Press Enter to continue... - -*Follow the procedure* - -Setting expected motor indices: [1, 2, 3, 4, 5, 6] -``` - -Once the leader arm is configured, repeat the process for the follower arm by running: -```python -follower_arm.connect() -``` - -Congratulations! Both arms are now properly configured and connected. You won't need to go through the configuration procedure again in the future. - -**Troubleshooting**: - -If the configuration process fails, you may need to do the configuration process via the Dynamixel Wizard. - -Known failure modes: -- Calling `arm.connect()` raises `OSError: No motor found, but one new motor expected. Verify power cord is plugged in and retry` on Ubuntu 22. - -Steps: -1. Visit https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_wizard2/#connect-dynamixel. -2. Follow the software installation instructions in section 3 of the web page. -3. Launch the software. -4. Configure the device scanning options in the menu under `Tools` > `Options` > `Scan`. Check only Protocol 2.0, select only the USB port identifier of interest, select all baudrates, set the ID range to `[0, 10]`. _While this step was not strictly necessary, it greatly speeds up scanning_. -5. For each motor in turn: - - Disconnect the power to the driver board. - - Connect **only** the motor of interest to the driver board, making sure to disconnect it from any other motors. - - Reconnect the power to the driver board. - - From the software menu select `Device` > `Scan` and let the scan run. A device should appear. - - If the device has an asterisk (*) near it, it means the firmware is indeed outdated. From the software menu, select `Tools` > `Firmware Update`. Follow the prompts. - - The main panel should have table with various parameters of the device (refer to the web page, section 5). Select the row with `ID`, and then set the desired ID on the bottom right panel by selecting and clicking `Save`. - - Just like you did with the ID, also set the `Baud Rate` to 1 Mbps. -6. Check everything has been done right: - - Rewire the arms in their final configuration and power both of them. - - Scan for devices. All 12 motors should appear. - - Select the motors one by one and move the arm. Check that the graphical indicator near the top right shows the movement. - -** There is a common issue with the Dynamixel XL430-W250 motors where the motors become undiscoverable after upgrading their firmware from Mac and Windows Dynamixel Wizard2 applications. When this occurs, it is required to do a firmware recovery (Select `DYNAMIXEL Firmware Recovery` and follow the prompts). There are two known workarounds to conduct this firmware reset: - 1) Install the Dynamixel Wizard on a linux machine and complete the firmware recovery - 2) Use the Dynamixel U2D2 in order to perform the reset with Windows or Mac. This U2D2 can be purchased [here](https://www.robotis.us/u2d2/). - For either solution, open DYNAMIXEL Wizard 2.0 and select the appropriate port. You will likely be unable to see the motor in the GUI at this time. Select `Firmware Recovery`, carefully choose the correct model, and wait for the process to complete. Finally, re-scan to confirm the firmware recovery was successful. - -**Read and Write with DynamixelMotorsBus** - -To get familiar with how `DynamixelMotorsBus` communicates with the motors, you can start by reading data from them. Copy past this code in the same interactive python session: -```python -leader_pos = leader_arm.read("Present_Position") -follower_pos = follower_arm.read("Present_Position") -print(leader_pos) -print(follower_pos) -``` - -Expected output might look like: -``` -array([2054, 523, 3071, 1831, 3049, 2441], dtype=int32) -array([2003, 1601, 56, 2152, 3101, 2283], dtype=int32) -``` - -Try moving the arms to various positions and observe how the values change. - -Now let's try to enable torque in the follower arm by copy pasting this code: -```python -from lerobot.common.robot_devices.motors.dynamixel import TorqueMode - -follower_arm.write("Torque_Enable", TorqueMode.ENABLED.value) -``` - -With torque enabled, the follower arm will be locked in its current position. Do not attempt to manually move the arm while torque is enabled, as this could damage the motors. - -Now, to get more familiar with reading and writing, let's move the arm programmatically copy pasting the following example code: -```python -# Get the current position -position = follower_arm.read("Present_Position") - -# Update first motor (shoulder_pan) position by +10 steps -position[0] += 10 -follower_arm.write("Goal_Position", position) - -# Update all motors position by -30 steps -position -= 30 -follower_arm.write("Goal_Position", position) - -# Update gripper by +30 steps -position[-1] += 30 -follower_arm.write("Goal_Position", position[-1], "gripper") -``` - -When you're done playing, you can try to disable the torque, but make sure you hold your robot so that it doesn't fall: -```python -follower_arm.write("Torque_Enable", TorqueMode.DISABLED.value) -``` - -Finally, disconnect the arms: -```python -leader_arm.disconnect() -follower_arm.disconnect() -``` - -Alternatively, you can unplug the power cord, which will automatically disable torque and disconnect the motors. - -*/!\ Warning*: These motors tend to overheat, especially under torque or if left plugged in for too long. Unplug after use. - -### b. Teleoperate your Koch v1.1 with ManipulatorRobot - -**Instantiate the ManipulatorRobot** - -Before you can teleoperate your robot, you need to instantiate the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) using the previously defined `leader_config` and `follower_config`. - -For the Koch v1.1 robot, we only have one leader, so we refer to it as `"main"` and define it as `leader_arms={"main": leader_config}`. We do the same for the follower arm. For other robots (like the Aloha), which may have two pairs of leader and follower arms, you would define them like this: `leader_arms={"left": left_leader_config, "right": right_leader_config},`. Same thing for the follower arms. - - -Run the following code to instantiate your manipulator robot: -```python -from lerobot.common.robot_devices.robots.configs import KochRobotConfig -from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot - -robot_config = KochRobotConfig( - leader_arms={"main": leader_config}, - follower_arms={"main": follower_config}, - cameras={}, # We don't use any camera for now -) -robot = ManipulatorRobot(robot_config) -``` - -The `KochRobotConfig` is used to set the associated settings and calibration process. For instance, we activate the torque of the gripper of the leader Koch v1.1 arm and position it at a 40 degree angle to use it as a trigger. - -For the [Aloha bimanual robot](https://aloha-2.github.io), we would use `AlohaRobotConfig` to set different settings such as a secondary ID for shadow joints (shoulder, elbow). Specific to Aloha, LeRobot comes with default calibration files stored in in `.cache/calibration/aloha_default`. Assuming the motors have been properly assembled, no manual calibration step is expected for Aloha. - -**Calibrate and Connect the ManipulatorRobot** - -Next, you'll need to calibrate your Koch robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one Koch robot to work on another. - -When you connect your robot for the first time, the [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) will detect if the calibration file is missing and trigger the calibration procedure. During this process, you will be guided to move each arm to three different positions. - -Here are the positions you'll move the follower arm to: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Koch v1.1 follower arm zero position | Koch v1.1 follower arm rotated position | Koch v1.1 follower arm rest position | - -And here are the corresponding positions for the leader arm: - -| 1. Zero position | 2. Rotated position | 3. Rest position | -| ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------- | -| Koch v1.1 leader arm zero position | Koch v1.1 leader arm rotated position | Koch v1.1 leader arm rest position | - -You can watch a [video tutorial of the calibration procedure](https://youtu.be/8drnU9uRY24) for more details. - -During calibration, we count the number of full 360-degree rotations your motors have made since they were first used. That's why we ask yo to move to this arbitrary "zero" position. We don't actually "set" the zero position, so you don't need to be accurate. After calculating these "offsets" to shift the motor values around 0, we need to assess the rotation direction of each motor, which might differ. That's why we ask you to rotate all motors to roughly 90 degrees, to measure if the values changed negatively or positively. - -Finally, the rest position ensures that the follower and leader arms are roughly aligned after calibration, preventing sudden movements that could damage the motors when starting teleoperation. - -Importantly, once calibrated, all Koch robots will move to the same positions (e.g. zero and rotated position) when commanded. - -Run the following code to calibrate and connect your robot: -```python -robot.connect() -``` - -The output will look like this: -``` -Connecting main follower arm -Connecting main leader arm - -Missing calibration file '.cache/calibration/koch/main_follower.json' -Running calibration of koch main follower... -Move arm to zero position -[...] -Move arm to rotated position -[...] -Move arm to rest position -[...] -Calibration is done! Saving calibration file '.cache/calibration/koch/main_follower.json' - -Missing calibration file '.cache/calibration/koch/main_leader.json' -Running calibration of koch main leader... -Move arm to zero position -[...] -Move arm to rotated position -[...] -Move arm to rest position -[...] -Calibration is done! Saving calibration file '.cache/calibration/koch/main_leader.json' -``` - -*Verifying Calibration* - -Once calibration is complete, you can check the positions of the leader and follower arms to ensure they match. If the calibration was successful, the positions should be very similar. - -Run this code to get the positions in degrees: -```python -leader_pos = robot.leader_arms["main"].read("Present_Position") -follower_pos = robot.follower_arms["main"].read("Present_Position") - -print(leader_pos) -print(follower_pos) -``` - -Example output: -``` -array([-0.43945312, 133.94531, 179.82422, -18.984375, -1.9335938, 34.541016], dtype=float32) -array([-0.58723712, 131.72314, 174.98743, -16.872612, 0.786213, 35.271973], dtype=float32) -``` - -These values are in degrees, which makes them easier to interpret and debug. The zero position used during calibration should roughly correspond to 0 degrees for each motor, and the rotated position should roughly correspond to 90 degrees for each motor. - -**Teleoperate your Koch v1.1** - -You can easily teleoperate your robot by reading the positions from the leader arm and sending them as goal positions to the follower arm. - -To teleoperate your robot for 30 seconds at a frequency of approximately 200Hz, run the following code: -```python -import tqdm -seconds = 30 -frequency = 200 -for _ in tqdm.tqdm(range(seconds*frequency)): - leader_pos = robot.leader_arms["main"].read("Present_Position") - robot.follower_arms["main"].write("Goal_Position", leader_pos) -``` - -*Using `teleop_step` for Teleoperation* - -Alternatively, you can teleoperate the robot using the `teleop_step` method from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py). - -Run this code to teleoperate: -```python -for _ in tqdm.tqdm(range(seconds*frequency)): - robot.teleop_step() -``` - -*Recording data during Teleoperation* - -Teleoperation is particularly useful for recording data. You can use the `teleop_step(record_data=True)` to returns both the follower arm's position as `"observation.state"` and the leader arm's position as `"action"`. This function also converts the numpy arrays into PyTorch tensors. If you're working with a robot that has two leader and two follower arms (like the Aloha), the positions are concatenated. - -Run the following code to see how slowly moving the leader arm affects the observation and action: -```python -leader_pos = robot.leader_arms["main"].read("Present_Position") -follower_pos = robot.follower_arms["main"].read("Present_Position") -observation, action = robot.teleop_step(record_data=True) - -print(follower_pos) -print(observation) -print(leader_pos) -print(action) -``` - -Expected output: -``` -array([7.8223, 131.1328, 165.5859, -23.4668, -0.9668, 32.4316], dtype=float32) -{'observation.state': tensor([7.8223, 131.1328, 165.5859, -23.4668, -0.9668, 32.4316])} -array([3.4277, 134.1211, 179.8242, -18.5449, -1.5820, 34.7168], dtype=float32) -{'action': tensor([3.4277, 134.1211, 179.8242, -18.5449, -1.5820, 34.7168])} -``` - -*Asynchronous Frame Recording* - -Additionally, `teleop_step` can asynchronously record frames from multiple cameras and include them in the observation dictionary as `"observation.images.CAMERA_NAME"`. This feature will be covered in more detail in the next section. - -*Disconnecting the Robot* - -When you're finished, make sure to disconnect your robot by running: -```python -robot.disconnect() -``` - -Alternatively, you can unplug the power cord, which will also disable torque. - -*/!\ Warning*: These motors tend to overheat, especially under torque or if left plugged in for too long. Unplug after use. - -### c. Add your cameras with OpenCVCamera - -**(Optional) Use your phone as camera on Linux** - -If you want to use your phone as a camera on Linux, follow these steps to set up a virtual camera - -1. *Install `v4l2loopback-dkms` and `v4l-utils`*. Those packages are required to create virtual camera devices (`v4l2loopback`) and verify their settings with the `v4l2-ctl` utility from `v4l-utils`. Install them using: -```python -sudo apt install v4l2loopback-dkms v4l-utils -``` -2. *Install [DroidCam](https://droidcam.app) on your phone*. This app is available for both iOS and Android. -3. *Install [OBS Studio](https://obsproject.com)*. This software will help you manage the camera feed. Install it using [Flatpak](https://flatpak.org): -```python -flatpak install flathub com.obsproject.Studio -``` -4. *Install the DroidCam OBS plugin*. This plugin integrates DroidCam with OBS Studio. Install it with: -```python -flatpak install flathub com.obsproject.Studio.Plugin.DroidCam -``` -5. *Start OBS Studio*. Launch with: -```python -flatpak run com.obsproject.Studio -``` -6. *Add your phone as a source*. Follow the instructions [here](https://droidcam.app/obs/usage). Be sure to set the resolution to `640x480`. -7. *Adjust resolution settings*. In OBS Studio, go to `File > Settings > Video`. Change the `Base(Canvas) Resolution` and the `Output(Scaled) Resolution` to `640x480` by manually typing it in. -8. *Start virtual camera*. In OBS Studio, follow the instructions [here](https://obsproject.com/kb/virtual-camera-guide). -9. *Verify the virtual camera setup*. Use `v4l2-ctl` to list the devices: -```python -v4l2-ctl --list-devices -``` -You should see an entry like: -``` -VirtualCam (platform:v4l2loopback-000): -/dev/video1 -``` -10. *Check the camera resolution*. Use `v4l2-ctl` to ensure that the virtual camera output resolution is `640x480`. Change `/dev/video1` to the port of your virtual camera from the output of `v4l2-ctl --list-devices`. -```python -v4l2-ctl -d /dev/video1 --get-fmt-video -``` -You should see an entry like: -``` ->>> Format Video Capture: ->>> Width/Height : 640/480 ->>> Pixel Format : 'YUYV' (YUYV 4:2:2) -``` - -Troubleshooting: If the resolution is not correct you will have to delete the Virtual Camera port and try again as it cannot be changed. - -If everything is set up correctly, you can proceed with the rest of the tutorial. - -**(Optional) Use your iPhone as a camera on MacOS** - -To use your iPhone as a camera on macOS, enable the Continuity Camera feature: -- Ensure your Mac is running macOS 13 or later, and your iPhone is on iOS 16 or later. -- Sign in both devices with the same Apple ID. -- Connect your devices with a USB cable or turn on Wi-Fi and Bluetooth for a wireless connection. - -For more details, visit [Apple support](https://support.apple.com/en-gb/guide/mac-help/mchl77879b8a/mac). - -Your iPhone should be detected automatically when running the camera setup script in the next section. - -**Instantiate an OpenCVCamera** - -The [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py) class allows you to efficiently record frames from most cameras using the [`opencv2`](https://docs.opencv.org) library. For more details on compatibility, see [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). - -To instantiate an [`OpenCVCamera`](../lerobot/common/robot_devices/cameras/opencv.py), you need a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera like a webcam of a laptop, the camera index is usually `0` but it might differ, and the camera index might change if you reboot your computer or re-plug your camera. This behavior depends on your operating system. - -To find the camera indices, run the following utility script, which will save a few frames from each detected camera: -```bash -python lerobot/common/robot_devices/cameras/opencv.py \ - --images-dir outputs/images_from_opencv_cameras -``` - -The output will look something like this if you have two cameras connected: -``` -Mac or Windows detected. Finding available camera indices through scanning all indices from 0 to 60 -[...] -Camera found at index 0 -Camera found at index 1 -[...] -Connecting cameras -OpenCVCamera(0, fps=30.0, width=1920.0, height=1080.0, color_mode=rgb) -OpenCVCamera(1, fps=24.0, width=1920.0, height=1080.0, color_mode=rgb) -Saving images to outputs/images_from_opencv_cameras -Frame: 0000 Latency (ms): 39.52 -[...] -Frame: 0046 Latency (ms): 40.07 -Images have been saved to outputs/images_from_opencv_cameras -``` - -Check the saved images in `outputs/images_from_opencv_cameras` to identify which camera index corresponds to which physical camera (e.g. `0` for `camera_00` or `1` for `camera_01`): -``` -camera_00_frame_000000.png -[...] -camera_00_frame_000047.png -camera_01_frame_000000.png -[...] -camera_01_frame_000047.png -``` - -Note: Some cameras may take a few seconds to warm up, and the first frame might be black or green. - -Finally, run this code to instantiate and connectyour camera: -```python -from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig -from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera - -config = OpenCVCameraConfig(camera_index=0) -camera = OpenCVCamera(config) -camera.connect() -color_image = camera.read() - -print(color_image.shape) -print(color_image.dtype) -``` - -Expected output for a laptop camera on MacBookPro: -``` -(1080, 1920, 3) -uint8 -``` - -Or like this if you followed our tutorial to set a virtual camera: -``` -(480, 640, 3) -uint8 -``` - -With certain camera, you can also specify additional parameters like frame rate, resolution, and color mode during instantiation. For instance: -```python -config = OpenCVCameraConfig(camera_index=0, fps=30, width=640, height=480) -``` - -If the provided arguments are not compatible with the camera, an exception will be raised. - -*Disconnecting the camera* - -When you're done using the camera, disconnect it by running: -```python -camera.disconnect() -``` - -**Instantiate your robot with cameras** - -Additionally, you can set up your robot to work with your cameras. - -Modify the following Python code with the appropriate camera names and configurations: -```python -robot = ManipulatorRobot( - KochRobotConfig( - leader_arms={"main": leader_arm}, - follower_arms={"main": follower_arm}, - calibration_dir=".cache/calibration/koch", - cameras={ - "laptop": OpenCVCameraConfig(0, fps=30, width=640, height=480), - "phone": OpenCVCameraConfig(1, fps=30, width=640, height=480), - }, - ) -) -robot.connect() -``` - -As a result, `teleop_step(record_data=True` will return a frame for each camera following the pytorch "channel first" convention but we keep images in `uint8` with pixels in range [0,255] to easily save them. - -Modify this code with the names of your cameras and run it: -```python -observation, action = robot.teleop_step(record_data=True) -print(observation["observation.images.laptop"].shape) -print(observation["observation.images.phone"].shape) -print(observation["observation.images.laptop"].min().item()) -print(observation["observation.images.laptop"].max().item()) -``` - -The output should look like this: -``` -torch.Size([3, 480, 640]) -torch.Size([3, 480, 640]) -0 -255 -``` - -### d. Use `control_robot.py` and our `teleoperate` function - -Instead of manually running the python code in a terminal window, you can use [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to instantiate your robot by providing the robot configurations via command line and control your robot with various modes as explained next. - -Try running this code to teleoperate your robot (if you dont have a camera, keep reading): -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=koch \ - --control.type=teleoperate -``` - -You will see a lot of lines appearing like this one: -``` -INFO 2024-08-10 11:15:03 ol_robot.py:209 dt: 5.12 (195.1hz) dtRlead: 4.93 (203.0hz) dtWfoll: 0.19 (5239.0hz) -``` - -It contains -- `2024-08-10 11:15:03` which is the date and time of the call to the print function. -- `ol_robot.py:209` which is the end of the file name and the line number where the print function is called (`lerobot/scripts/control_robot.py` line `209`). -- `dt: 5.12 (195.1hz)` which is the "delta time" or the number of milliseconds spent between the previous call to `robot.teleop_step()` and the current one, associated with the frequency (5.12 ms equals 195.1 Hz) ; note that you can control the maximum frequency by adding fps as argument such as `--fps 30`. -- `dtRlead: 4.93 (203.0hz)` which is the number of milliseconds it took to read the position of the leader arm using `leader_arm.read("Present_Position")`. -- `dtWfoll: 0.22 (4446.9hz)` which is the number of milliseconds it took to set a new goal position for the follower arm using `follower_arm.write("Goal_position", leader_pos)` ; note that writing is done asynchronously so it takes less time than reading. - -Importantly: If you don't have any camera, you can remove them dynamically with this [draccus](https://github.com/dlwh/draccus) syntax `--robot.cameras='{}'`: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=koch \ - --robot.cameras='{}' \ - --control.type=teleoperate -``` - -We advise to create a new yaml file when the command becomes too long. - -## 3. Record your Dataset and Visualize it - -Using what you've learned previously, you can now easily record a dataset of states and actions for one episode. You can use `busy_wait` to control the speed of teleoperation and record at a fixed `fps` (frame per seconds). - -Try this code to record 30 seconds at 60 fps: -```python -import time -from lerobot.scripts.control_robot import busy_wait - -record_time_s = 30 -fps = 60 - -states = [] -actions = [] -for _ in range(record_time_s * fps): - start_time = time.perf_counter() - observation, action = robot.teleop_step(record_data=True) - - states.append(observation["observation.state"]) - actions.append(action["action"]) - - dt_s = time.perf_counter() - start_time - busy_wait(1 / fps - dt_s) - -# Note that observation and action are available in RAM, but -# you could potentially store them on disk with pickle/hdf5 or -# our optimized format `LeRobotDataset`. More on this next. -``` - -Importantly, many utilities are still missing. For instance, if you have cameras, you will need to save the images on disk to not go out of RAM, and to do so in threads to not slow down communication with your robot. Also, you will need to store your data in a format optimized for training and web sharing like [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py). More on this in the next section. - -### a. Use the `record` function - -You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) to achieve efficient data recording. It encompasses many recording utilities: -1. Frames from cameras are saved on disk in threads, and encoded into videos at the end of each episode recording. -2. Video streams from cameras are displayed in window so that you can verify them. -3. Data is stored with [`LeRobotDataset`](../lerobot/common/datasets/lerobot_dataset.py) format which is pushed to your Hugging Face page (unless `--control.push_to_hub=false` is provided). -4. Checkpoints are done during recording, so if any issue occurs, you can resume recording by re-running the same command again with `--control.resume=true`. You will need to manually delete the dataset directory if you want to start recording from scratch. -5. Set the flow of data recording using command line arguments: - - `--control.warmup_time_s=10` defines the number of seconds before starting data collection. It allows the robot devices to warmup and synchronize (10 seconds by default). - - `--control.episode_time_s=60` defines the number of seconds for data recording for each episode (60 seconds by default). - - `--control.reset_time_s=60` defines the number of seconds for resetting the environment after each episode (60 seconds by default). - - `--control.num_episodes=50` defines the number of episodes to record (50 by default). -6. Control the flow during data recording using keyboard keys: - - Press right arrow `->` at any time during episode recording to early stop and go to resetting. Same during resetting, to early stop and to go to the next episode recording. - - Press left arrow `<-` at any time during episode recording or resetting to early stop, cancel the current episode, and re-record it. - - Press escape `ESC` at any time during episode recording to end the session early and go straight to video encoding and dataset uploading. -7. Similarly to `teleoperate`, you can also use the command line to override anything. - -Before trying `record`, if you want to push your dataset to the hub, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): -```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential -``` -Also, store your Hugging Face repository name in a variable (e.g. `cadene` or `lerobot`). For instance, run this to use your Hugging Face user name as repository: -```bash -HF_USER=$(huggingface-cli whoami | head -n 1) -echo $HF_USER -``` -If you don't want to push to hub, use `--control.push_to_hub=false`. - -Now run this to record 2 episodes: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=koch \ - --control.type=record \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.fps=30 \ - --control.repo_id=${HF_USER}/koch_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=2 \ - --control.push_to_hub=true -``` - - -This will write your dataset locally to `~/.cache/huggingface/lerobot/{repo-id}` (e.g. `data/cadene/koch_test`) and push it on the hub at `https://huggingface.co/datasets/{HF_USER}/{repo-id}`. Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example). - -You can look for other LeRobot datasets on the hub by searching for `LeRobot` tags: https://huggingface.co/datasets?other=LeRobot - -You will see a lot of lines appearing like this one: -``` -INFO 2024-08-10 15:02:58 ol_robot.py:219 dt:33.34 (30.0hz) dtRlead: 5.06 (197.5hz) dtWfoll: 0.25 (3963.7hz) dtRfoll: 6.22 (160.7hz) dtRlaptop: 32.57 (30.7hz) dtRphone: 33.84 (29.5hz) -``` -It contains: -- `2024-08-10 15:02:58` which is the date and time of the call to the print function, -- `ol_robot.py:219` which is the end of the file name and the line number where the print function is called (`lerobot/scripts/control_robot.py` line `219`). -- `dt:33.34 (30.0hz)` which is the "delta time" or the number of milliseconds spent between the previous call to `robot.teleop_step(record_data=True)` and the current one, associated with the frequency (33.34 ms equals 30.0 Hz) ; note that we use `--fps 30` so we expect 30.0 Hz ; when a step takes more time, the line appears in yellow. -- `dtRlead: 5.06 (197.5hz)` which is the delta time of reading the present position of the leader arm. -- `dtWfoll: 0.25 (3963.7hz)` which is the delta time of writing the goal position on the follower arm ; writing is asynchronous so it takes less time than reading. -- `dtRfoll: 6.22 (160.7hz)` which is the delta time of reading the present position on the follower arm. -- `dtRlaptop:32.57 (30.7hz) ` which is the delta time of capturing an image from the laptop camera in the thread running asynchronously. -- `dtRphone:33.84 (29.5hz)` which is the delta time of capturing an image from the phone camera in the thread running asynchronously. - -Troubleshooting: -- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). - -At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/koch_test) that you can obtain by running: -```bash -echo https://huggingface.co/datasets/${HF_USER}/koch_test -``` - -### b. Advice for recording dataset - -Once you're comfortable with data recording, it's time to create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. - -In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions. - -Avoid adding too much variation too quickly, as it may hinder your results. - -In the coming months, we plan to release a foundational model for robotics. We anticipate that fine-tuning this model will enhance generalization, reducing the need for strict consistency during data collection. - -### c. Visualize all episodes - -You can visualize your dataset by running: -```bash -python lerobot/scripts/visualize_dataset_html.py \ - --repo-id ${HF_USER}/koch_test -``` - -Note: You might need to add `--local-files-only 1` if your dataset was not uploaded to hugging face hub. - -This will launch a local web server that looks like this: -
- Koch v1.1 leader and follower arms -
- -### d. Replay episode on your robot with the `replay` function - -A useful feature of [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) is the `replay` function, which allows to replay on your robot any episode that you've recorded or episodes from any dataset out there. This function helps you test the repeatability of your robot's actions and assess transferability across robots of the same model. - -To replay the first episode of the dataset you just recorded, run the following command: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=koch \ - --control.type=replay \ - --control.fps=30 \ - --control.repo_id=${HF_USER}/koch_test \ - --control.episode=0 -``` - -Your robot should replicate movements similar to those you recorded. For example, check out [this video](https://x.com/RemiCadene/status/1793654950905680090) where we use `replay` on a Aloha robot from [Trossen Robotics](https://www.trossenrobotics.com). - -## 4. Train a policy on your data - -### a. Use the `train` script - -To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: -```bash -python lerobot/scripts/train.py \ - --dataset.repo_id=${HF_USER}/koch_test \ - --policy.type=act \ - --output_dir=outputs/train/act_koch_test \ - --job_name=act_koch_test \ - --policy.device=cuda \ - --wandb.enable=true -``` - -Let's explain it: -1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/koch_test`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. - -For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md) - -### b. (Optional) Upload policy checkpoints to the hub - -Once training is done, upload the latest checkpoint with: -```bash -huggingface-cli upload ${HF_USER}/act_koch_test \ - outputs/train/act_koch_test/checkpoints/last/pretrained_model -``` - -You can also upload intermediate checkpoints with: -```bash -CKPT=010000 -huggingface-cli upload ${HF_USER}/act_koch_test_${CKPT} \ - outputs/train/act_koch_test/checkpoints/${CKPT}/pretrained_model -``` - -## 5. Evaluate your policy - -Now that you have a policy checkpoint, you can easily control your robot with it using methods from [`ManipulatorRobot`](../lerobot/common/robot_devices/robots/manipulator.py) and the policy. - -Try this code for running inference for 60 seconds at 30 fps: -```python -from lerobot.common.policies.act.modeling_act import ACTPolicy - -inference_time_s = 60 -fps = 30 -device = "cuda" # TODO: On Mac, use "mps" or "cpu" - -ckpt_path = "outputs/train/act_koch_test/checkpoints/last/pretrained_model" -policy = ACTPolicy.from_pretrained(ckpt_path) -policy.to(device) - -for _ in range(inference_time_s * fps): - start_time = time.perf_counter() - - # Read the follower state and access the frames from the cameras - observation = robot.capture_observation() - - # Convert to pytorch format: channel first and float32 in [0,1] - # with batch dimension - for name in observation: - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - # Compute the next action with the policy - # based on the current observation - action = policy.select_action(observation) - # Remove batch dimension - action = action.squeeze(0) - # Move to cpu, if not already the case - action = action.to("cpu") - # Order the robot to move - robot.send_action(action) - - dt_s = time.perf_counter() - start_time - busy_wait(1 / fps - dt_s) -``` - -### a. Use our `record` function - -Ideally, when controlling your robot with your neural network, you would want to record evaluation episodes and to be able to visualize them later on, or even train on them like in Reinforcement Learning. This pretty much corresponds to recording a new dataset but with a neural network providing the actions instead of teleoperation. - -To this end, you can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=koch \ - --control.type=record \ - --control.fps=30 \ - --control.repo_id=${HF_USER}/eval_act_koch_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=10 \ - --control.push_to_hub=true \ - --control.policy.path=outputs/train/act_koch_test/checkpoints/last/pretrained_model -``` - -As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_koch_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_koch_test`). -2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_koch_test`). - -### b. Visualize evaluation afterwards - -You can then visualize your evaluation dataset by running the same command as before but with the new inference dataset as argument: -```bash -python lerobot/scripts/visualize_dataset.py \ - --repo-id ${HF_USER}/eval_act_koch_test -``` - -## 6. Next step - -Join our [Discord](https://discord.com/invite/s3KuuzsPFb) to collaborate on data collection and help us train a fully open-source foundational models for robotics! diff --git a/examples/9_use_aloha.md b/examples/9_use_aloha.md deleted file mode 100644 index 77cff1611..000000000 --- a/examples/9_use_aloha.md +++ /dev/null @@ -1,182 +0,0 @@ -This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot. - -## Setup - -Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer. - - -## Install LeRobot - -On your computer: - -1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install): -```bash -mkdir -p ~/miniconda3 -wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh -bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 -rm ~/miniconda3/miniconda.sh -~/miniconda3/bin/conda init bash -``` - -2. Restart shell or `source ~/.bashrc` - -3. Create and activate a fresh conda environment for lerobot -```bash -conda create -y -n lerobot python=3.10 && conda activate lerobot -``` - -4. Clone LeRobot: -```bash -git clone https://github.com/huggingface/lerobot.git ~/lerobot -``` - -5. When using `miniconda`, install `ffmpeg` in your environment: -```bash -conda install ffmpeg -c conda-forge -``` - -6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense): -```bash -cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]" -``` - -## Teleoperate - -**/!\ FOR SAFETY, READ THIS /!\** -Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly: -1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms, -2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics. - -By running the following code, you can start your first **SAFE** teleoperation: - -> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. - -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=aloha \ - --robot.max_relative_target=5 \ - --control.type=teleoperate -``` - -By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/common/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=aloha \ - --robot.max_relative_target=null \ - --control.type=teleoperate -``` - -## Record a dataset - -Once you're familiar with teleoperation, you can record your first dataset with Aloha. - -If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): -```bash -huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential -``` - -Store your Hugging Face repository name in a variable to run these commands: -```bash -HF_USER=$(huggingface-cli whoami | head -n 1) -echo $HF_USER -``` - -Record 2 episodes and upload your dataset to the hub: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=aloha \ - --robot.max_relative_target=null \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=${HF_USER}/aloha_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=2 \ - --control.push_to_hub=true -``` - -## Visualize a dataset - -If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: -```bash -echo ${HF_USER}/aloha_test -``` - -If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with: -```bash -python lerobot/scripts/visualize_dataset_html.py \ - --repo-id ${HF_USER}/aloha_test -``` - -## Replay an episode - -**/!\ FOR SAFETY, READ THIS /!\** -Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above. - -Now try to replay the first episode on your robot: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=aloha \ - --robot.max_relative_target=null \ - --control.type=replay \ - --control.fps=30 \ - --control.repo_id=${HF_USER}/aloha_test \ - --control.episode=0 -``` - -## Train a policy - -To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: -```bash -python lerobot/scripts/train.py \ - --dataset.repo_id=${HF_USER}/aloha_test \ - --policy.type=act \ - --output_dir=outputs/train/act_aloha_test \ - --job_name=act_aloha_test \ - --policy.device=cuda \ - --wandb.enable=true -``` - -Let's explain it: -1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor sates, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. -4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. -5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. - -For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md) - -Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`. - -## Evaluate your policy - -You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=aloha \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=${HF_USER}/eval_act_aloha_test \ - --control.tags='["tutorial"]' \ - --control.warmup_time_s=5 \ - --control.episode_time_s=30 \ - --control.reset_time_s=30 \ - --control.num_episodes=10 \ - --control.push_to_hub=true \ - --control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \ - --control.num_image_writer_processes=1 -``` - -As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: -1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`). -2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`). -3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`. - -## More - -Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation. - -If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`. diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index 47b4dd028..aac8e2e4e 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -66,7 +66,7 @@ def main(): print(f"Number of episodes in full dataset: {total_episodes}") print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}") print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}") - # - Load train an val datasets + # - Load train and val datasets train_dataset = LeRobotDataset( "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps ) diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py new file mode 100644 index 000000000..11684d064 --- /dev/null +++ b/examples/backward_compatibility/replay.py @@ -0,0 +1,105 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Replays the actions of an episode from a dataset on a robot. + +Example: + +```shell +python -m lerobot.replay \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --dataset.repo_id=aliberts/record-test \ + --dataset.episode=2 +``` +""" + +import logging +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from pprint import pformat + +import draccus + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, + so101_follower, +) +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import ( + init_logging, + log_say, +) + + +@dataclass +class DatasetReplayConfig: + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str + # Episode to replay. + episode: int + # Root directory where the dataset will be stored (e.g. 'dataset/path'). + root: str | Path | None = None + # Limit the frames per second. By default, uses the policy fps. + fps: int = 30 + + +@dataclass +class ReplayConfig: + robot: RobotConfig + dataset: DatasetReplayConfig + # Use vocal synthesis to read events. + play_sounds: bool = True + + +@draccus.wrap() +def replay(cfg: ReplayConfig): + init_logging() + logging.info(pformat(asdict(cfg))) + + robot = make_robot_from_config(cfg.robot) + dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) + actions = dataset.hf_dataset.select_columns("action") + robot.connect() + + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action_array = actions[idx]["action"] + action = {} + for i, name in enumerate(dataset.features["action"]["names"]): + key = f"{name.removeprefix('main_')}.pos" + action[key] = action_array[i].item() + + action["shoulder_lift.pos"] = -(action["shoulder_lift.pos"] - 90) + action["elbow_flex.pos"] -= 90 + robot.send_action(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / dataset.fps - dt_s) + + robot.disconnect() + + +if __name__ == "__main__": + replay() diff --git a/examples/lekiwi/evaluate.py b/examples/lekiwi/evaluate.py new file mode 100644 index 000000000..2a41440a3 --- /dev/null +++ b/examples/lekiwi/evaluate.py @@ -0,0 +1,32 @@ +from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.common.policies.act.modeling_act import ACTPolicy +from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig +from lerobot.common.utils.control_utils import predict_action +from lerobot.common.utils.utils import get_safe_torch_device + +NB_CYCLES_CLIENT_CONNECTION = 1000 + +robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") +robot = LeKiwiClient(robot_config) + +robot.connect() + +policy = ACTPolicy.from_pretrained("pepijn223/act_lekiwi_circle") +policy.reset() + +obs_features = hw_to_dataset_features(robot.observation_features, "observation") + +print("Running inference") +i = 0 +while i < NB_CYCLES_CLIENT_CONNECTION: + obs = robot.get_observation() + + observation_frame = build_dataset_frame(obs_features, obs, prefix="observation") + action_values = predict_action( + observation_frame, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp + ) + action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} + robot.send_action(action) + i += 1 + +robot.disconnect() diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py new file mode 100644 index 000000000..405a41bd3 --- /dev/null +++ b/examples/lekiwi/record.py @@ -0,0 +1,67 @@ +import time + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import hw_to_dataset_features +from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig +from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient +from lerobot.common.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig +from lerobot.common.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig + +NB_CYCLES_CLIENT_CONNECTION = 250 + +leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem58760431551") +leader_arm = SO100Leader(leader_arm_config) + +keyboard_config = KeyboardTeleopConfig() +keyboard = KeyboardTeleop(keyboard_config) + +robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") +robot = LeKiwiClient(robot_config) + +action_features = hw_to_dataset_features(robot.action_features, "action") +obs_features = hw_to_dataset_features(robot.observation_features, "observation") +dataset_features = {**action_features, **obs_features} + +dataset = LeRobotDataset.create( + repo_id="pepijn223/lekiwi" + str(int(time.time())), + fps=10, + features=dataset_features, + robot_type=robot.name, +) + +leader_arm.connect() +keyboard.connect() +robot.connect() + +if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected: + exit() + +print("Starting LeKiwi recording") +i = 0 +while i < NB_CYCLES_CLIENT_CONNECTION: + arm_action = leader_arm.get_action() + arm_action = {f"arm_{k}": v for k, v in arm_action.items()} + + keyboard_keys = keyboard.get_action() + + base_action = robot._from_keyboard_to_base_action(keyboard_keys) + + action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action + + action_sent = robot.send_action(action) + observation = robot.get_observation() + + frame = {**action_sent, **observation} + task = "Dummy Example Task Dataset" + + dataset.add_frame(frame, task) + i += 1 + +print("Disconnecting Teleop Devices and LeKiwi Client") +robot.disconnect() +leader_arm.disconnect() +keyboard.disconnect() + +print("Uploading dataset to the hub") +dataset.save_episode() +dataset.push_to_hub() diff --git a/examples/lekiwi/replay.py b/examples/lekiwi/replay.py new file mode 100644 index 000000000..f69092de0 --- /dev/null +++ b/examples/lekiwi/replay.py @@ -0,0 +1,25 @@ +import time + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.robots.lekiwi.config_lekiwi import LeKiwiClientConfig +from lerobot.common.robots.lekiwi.lekiwi_client import LeKiwiClient +from lerobot.common.utils.robot_utils import busy_wait + +robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") +robot = LeKiwiClient(robot_config) + +dataset = LeRobotDataset("pepijn223/lekiwi1749025613", episodes=[0]) + +robot.connect() + +print("Replaying episode…") +for _, action_array in enumerate(dataset.hf_dataset["action"]): + t0 = time.perf_counter() + + action = {name: float(action_array[i]) for i, name in enumerate(dataset.features["action"]["names"])} + robot.send_action(action) + + busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0)) + +print("Disconnecting LeKiwi Client") +robot.disconnect() diff --git a/examples/lekiwi/teleoperate.py b/examples/lekiwi/teleoperate.py new file mode 100644 index 000000000..2fe85d94e --- /dev/null +++ b/examples/lekiwi/teleoperate.py @@ -0,0 +1,32 @@ +from lerobot.common.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig +from lerobot.common.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig +from lerobot.common.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig + +robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi") + +teleop__arm_config = SO100LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_awesome_leader_arm", +) + +teleop_keyboard_config = KeyboardTeleopConfig( + id="my_laptop_keyboard", +) + +robot = LeKiwiClient(robot_config) +teleop_arm = SO100Leader(teleop__arm_config) +telep_keyboard = KeyboardTeleop(teleop_keyboard_config) +robot.connect() +teleop_arm.connect() +telep_keyboard.connect() + +while True: + observation = robot.get_observation() + + arm_action = teleop_arm.get_action() + arm_action = {f"arm_{k}": v for k, v in arm_action.items()} + + keyboard_keys = telep_keyboard.get_action() + base_action = robot._from_keyboard_to_base_action(keyboard_keys) + + robot.send_action(arm_action | base_action) diff --git a/lerobot/__init__.py b/lerobot/__init__.py index d61e4853e..11114da0a 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -168,12 +168,7 @@ available_datasets = sorted( ) # lists all available policies from `lerobot/common/policies` -available_policies = [ - "act", - "diffusion", - "tdmpc", - "vqbet", -] +available_policies = ["act", "diffusion", "tdmpc", "vqbet"] # lists all available robots from `lerobot/common/robot_devices/robots` available_robots = [ @@ -181,7 +176,7 @@ available_robots = [ "koch_bimanual", "aloha", "so100", - "moss", + "so101", ] # lists all available cameras from `lerobot/common/robot_devices/cameras` diff --git a/lerobot/calibrate.py b/lerobot/calibrate.py new file mode 100644 index 000000000..6780577ff --- /dev/null +++ b/lerobot/calibrate.py @@ -0,0 +1,84 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper to recalibrate your device (robot or teleoperator). + +Example: + +```shell +python -m lerobot.calibrate \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue +``` +""" + +import logging +from dataclasses import asdict, dataclass +from pprint import pformat + +import draccus + +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.common.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.common.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + lekiwi, + make_robot_from_config, + so100_follower, + so101_follower, +) +from lerobot.common.teleoperators import ( # noqa: F401 + Teleoperator, + TeleoperatorConfig, + koch_leader, + make_teleoperator_from_config, + so100_leader, + so101_leader, +) +from lerobot.common.utils.utils import init_logging + + +@dataclass +class CalibrateConfig: + teleop: TeleoperatorConfig | None = None + robot: RobotConfig | None = None + + def __post_init__(self): + if bool(self.teleop) == bool(self.robot): + raise ValueError("Choose either a teleop or a robot.") + + self.device = self.robot if self.robot else self.teleop + + +@draccus.wrap() +def calibrate(cfg: CalibrateConfig): + init_logging() + logging.info(pformat(asdict(cfg))) + + if isinstance(cfg.device, RobotConfig): + device = make_robot_from_config(cfg.device) + elif isinstance(cfg.device, TeleoperatorConfig): + device = make_teleoperator_from_config(cfg.device) + + device.connect(calibrate=False) + device.calibrate() + device.disconnect() + + +if __name__ == "__main__": + calibrate() diff --git a/lerobot/common/cameras/__init__.py b/lerobot/common/cameras/__init__.py new file mode 100644 index 000000000..1488cd89e --- /dev/null +++ b/lerobot/common/cameras/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .camera import Camera +from .configs import CameraConfig, ColorMode, Cv2Rotation +from .utils import make_cameras_from_configs diff --git a/lerobot/common/cameras/camera.py b/lerobot/common/cameras/camera.py new file mode 100644 index 000000000..1937205b1 --- /dev/null +++ b/lerobot/common/cameras/camera.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Dict, List + +import numpy as np + +from .configs import CameraConfig, ColorMode + + +class Camera(abc.ABC): + """Base class for camera implementations. + + Defines a standard interface for camera operations across different backends. + Subclasses must implement all abstract methods. + + Manages basic camera properties (FPS, resolution) and core operations: + - Connection/disconnection + - Frame capture (sync/async) + + Attributes: + fps (int | None): Configured frames per second + width (int | None): Frame width in pixels + height (int | None): Frame height in pixels + + Example: + class MyCamera(Camera): + def __init__(self, config): ... + @property + def is_connected(self) -> bool: ... + def connect(self, warmup=True): ... + # Plus other required methods + """ + + def __init__(self, config: CameraConfig): + """Initialize the camera with the given configuration. + + Args: + config: Camera configuration containing FPS and resolution. + """ + self.fps: int | None = config.fps + self.width: int | None = config.width + self.height: int | None = config.height + + @property + @abc.abstractmethod + def is_connected(self) -> bool: + """Check if the camera is currently connected. + + Returns: + bool: True if the camera is connected and ready to capture frames, + False otherwise. + """ + pass + + @staticmethod + @abc.abstractmethod + def find_cameras() -> List[Dict[str, Any]]: + """Detects available cameras connected to the system. + Returns: + List[Dict[str, Any]]: A list of dictionaries, + where each dictionary contains information about a detected camera. + """ + pass + + @abc.abstractmethod + def connect(self, warmup: bool = True) -> None: + """Establish connection to the camera. + + Args: + warmup: If True (default), captures a warmup frame before returning. Useful + for cameras that require time to adjust capture settings. + If False, skips the warmup frame. + """ + pass + + @abc.abstractmethod + def read(self, color_mode: ColorMode | None = None) -> np.ndarray: + """Capture and return a single frame from the camera. + + Args: + color_mode: Desired color mode for the output frame. If None, + uses the camera's default color mode. + + Returns: + np.ndarray: Captured frame as a numpy array. + """ + pass + + @abc.abstractmethod + def async_read(self, timeout_ms: float = ...) -> np.ndarray: + """Asynchronously capture and return a single frame from the camera. + + Args: + timeout_ms: Maximum time to wait for a frame in milliseconds. + Defaults to implementation-specific timeout. + + Returns: + np.ndarray: Captured frame as a numpy array. + """ + pass + + @abc.abstractmethod + def disconnect(self) -> None: + """Disconnect from the camera and release resources.""" + pass diff --git a/lerobot/common/robot_devices/motors/configs.py b/lerobot/common/cameras/configs.py similarity index 64% rename from lerobot/common/robot_devices/motors/configs.py rename to lerobot/common/cameras/configs.py index 0bfbaf837..0488a97ff 100644 --- a/lerobot/common/robot_devices/motors/configs.py +++ b/lerobot/common/cameras/configs.py @@ -1,3 +1,5 @@ +#!/usr/bin/env python + # Copyright 2024 The HuggingFace Inc. team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,28 +16,29 @@ import abc from dataclasses import dataclass +from enum import Enum import draccus -@dataclass -class MotorsBusConfig(draccus.ChoiceRegistry, abc.ABC): +class ColorMode(str, Enum): + RGB = "rgb" + BGR = "bgr" + + +class Cv2Rotation(int, Enum): + NO_ROTATION = 0 + ROTATE_90 = 90 + ROTATE_180 = 180 + ROTATE_270 = -90 + + +@dataclass(kw_only=True) +class CameraConfig(draccus.ChoiceRegistry, abc.ABC): + fps: int | None = None + width: int | None = None + height: int | None = None + @property def type(self) -> str: return self.get_choice_name(self.__class__) - - -@MotorsBusConfig.register_subclass("dynamixel") -@dataclass -class DynamixelMotorsBusConfig(MotorsBusConfig): - port: str - motors: dict[str, tuple[int, str]] - mock: bool = False - - -@MotorsBusConfig.register_subclass("feetech") -@dataclass -class FeetechMotorsBusConfig(MotorsBusConfig): - port: str - motors: dict[str, tuple[int, str]] - mock: bool = False diff --git a/lerobot/common/cameras/opencv/__init__.py b/lerobot/common/cameras/opencv/__init__.py new file mode 100644 index 000000000..11d3139fe --- /dev/null +++ b/lerobot/common/cameras/opencv/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .camera_opencv import OpenCVCamera +from .configuration_opencv import OpenCVCameraConfig diff --git a/lerobot/common/cameras/opencv/camera_opencv.py b/lerobot/common/cameras/opencv/camera_opencv.py new file mode 100644 index 000000000..3e9370fc4 --- /dev/null +++ b/lerobot/common/cameras/opencv/camera_opencv.py @@ -0,0 +1,482 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Provides the OpenCVCamera class for capturing frames from cameras using OpenCV. +""" + +import logging +import math +import platform +import time +from pathlib import Path +from threading import Event, Lock, Thread +from typing import Any, Dict, List + +import cv2 +import numpy as np + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..camera import Camera +from ..utils import get_cv2_backend, get_cv2_rotation +from .configuration_opencv import ColorMode, OpenCVCameraConfig + +# NOTE(Steven): The maximum opencv device index depends on your operating system. For instance, +# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case +# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23. +# When you change the USB port or reboot the computer, the operating system might +# treat the same cameras as new devices. Thus we select a higher bound to search indices. +MAX_OPENCV_INDEX = 60 + +logger = logging.getLogger(__name__) + + +class OpenCVCamera(Camera): + """ + Manages camera interactions using OpenCV for efficient frame recording. + + This class provides a high-level interface to connect to, configure, and read + frames from cameras compatible with OpenCV's VideoCapture. It supports both + synchronous and asynchronous frame reading. + + An OpenCVCamera instance requires a camera index (e.g., 0) or a device path + (e.g., '/dev/video0' on Linux). Camera indices can be unstable across reboots + or port changes, especially on Linux. Use the provided utility script to find + available camera indices or paths: + ```bash + python -m lerobot.find_cameras opencv + ``` + + The camera's default settings (FPS, resolution, color mode) are used unless + overridden in the configuration. + + Example: + ```python + from lerobot.common.cameras.opencv import OpenCVCamera + from lerobot.common.cameras.configuration_opencv import OpenCVCameraConfig, ColorMode, Cv2Rotation + + # Basic usage with camera index 0 + config = OpenCVCameraConfig(index_or_path=0) + camera = OpenCVCamera(config) + camera.connect() + + # Read 1 frame synchronously + color_image = camera.read() + print(color_image.shape) + + # Read 1 frame asynchronously + async_image = camera.async_read() + + # When done, properly disconnect the camera using + camera.disconnect() + + # Example with custom settings + custom_config = OpenCVCameraConfig( + index_or_path='/dev/video0', # Or use an index + fps=30, + width=1280, + height=720, + color_mode=ColorMode.RGB, + rotation=Cv2Rotation.ROTATE_90 + ) + custom_camera = OpenCVCamera(custom_config) + # ... connect, read, disconnect ... + ``` + """ + + def __init__(self, config: OpenCVCameraConfig): + """ + Initializes the OpenCVCamera instance. + + Args: + config: The configuration settings for the camera. + """ + super().__init__(config) + + self.config = config + self.index_or_path = config.index_or_path + + self.fps = config.fps + self.color_mode = config.color_mode + self.warmup_s = config.warmup_s + + self.videocapture: cv2.VideoCapture | None = None + + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_lock: Lock = Lock() + self.latest_frame: np.ndarray | None = None + self.new_frame_event: Event = Event() + + self.rotation: int | None = get_cv2_rotation(config.rotation) + self.backend: int = get_cv2_backend() + + if self.height and self.width: + self.capture_width, self.capture_height = self.width, self.height + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.capture_width, self.capture_height = self.height, self.width + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.index_or_path})" + + @property + def is_connected(self) -> bool: + """Checks if the camera is currently connected and opened.""" + return isinstance(self.videocapture, cv2.VideoCapture) and self.videocapture.isOpened() + + def connect(self, warmup: bool = True): + """ + Connects to the OpenCV camera specified in the configuration. + + Initializes the OpenCV VideoCapture object, sets desired camera properties + (FPS, width, height), and performs initial checks. + + Raises: + DeviceAlreadyConnectedError: If the camera is already connected. + ConnectionError: If the specified camera index/path is not found or the camera is found but fails to open. + RuntimeError: If the camera opens but fails to apply requested FPS/resolution settings. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") + + # Use 1 thread for OpenCV operations to avoid potential conflicts or + # blocking in multi-threaded applications, especially during data collection. + cv2.setNumThreads(1) + + self.videocapture = cv2.VideoCapture(self.index_or_path, self.backend) + + if not self.videocapture.isOpened(): + self.videocapture.release() + self.videocapture = None + raise ConnectionError( + f"Failed to open {self}." + f"Run `python -m lerobot.find_cameras opencv` to find available cameras." + ) + + self._configure_capture_settings() + + if warmup: + start_time = time.time() + while time.time() - start_time < self.warmup_s: + self.read() + time.sleep(0.1) + + logger.info(f"{self} connected.") + + def _configure_capture_settings(self) -> None: + """ + Applies the specified FPS, width, and height settings to the connected camera. + + This method attempts to set the camera properties via OpenCV. It checks if + the camera successfully applied the settings and raises an error if not. + + Args: + fps: The desired frames per second. If None, the setting is skipped. + width: The desired capture width. If None, the setting is skipped. + height: The desired capture height. If None, the setting is skipped. + + Raises: + RuntimeError: If the camera fails to set any of the specified properties + to the requested value. + DeviceNotConnectedError: If the camera is not connected when attempting + to configure settings. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"Cannot configure settings for {self} as it is not connected.") + + if self.fps is None: + self.fps = self.videocapture.get(cv2.CAP_PROP_FPS) + else: + self._validate_fps() + + default_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH))) + default_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT))) + + if self.width is None or self.height is None: + self.width, self.height = default_width, default_height + self.capture_width, self.capture_height = default_width, default_height + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.width, self.height = default_height, default_width + self.capture_width, self.capture_height = default_width, default_height + else: + self._validate_width_and_height() + + def _validate_fps(self) -> None: + """Validates and sets the camera's frames per second (FPS).""" + + success = self.videocapture.set(cv2.CAP_PROP_FPS, float(self.fps)) + actual_fps = self.videocapture.get(cv2.CAP_PROP_FPS) + # Use math.isclose for robust float comparison + if not success or not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + raise RuntimeError(f"{self} failed to set fps={self.fps} ({actual_fps=}).") + + def _validate_width_and_height(self) -> None: + """Validates and sets the camera's frame capture width and height.""" + + width_success = self.videocapture.set(cv2.CAP_PROP_FRAME_WIDTH, float(self.capture_width)) + height_success = self.videocapture.set(cv2.CAP_PROP_FRAME_HEIGHT, float(self.capture_height)) + + actual_width = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_WIDTH))) + if not width_success or self.capture_width != actual_width: + raise RuntimeError( + f"{self} failed to set capture_width={self.capture_width} ({actual_width=}, {width_success=})." + ) + + actual_height = int(round(self.videocapture.get(cv2.CAP_PROP_FRAME_HEIGHT))) + if not height_success or self.capture_height != actual_height: + raise RuntimeError( + f"{self} failed to set capture_height={self.capture_height} ({actual_height=}, {height_success=})." + ) + + @staticmethod + def find_cameras() -> List[Dict[str, Any]]: + """ + Detects available OpenCV cameras connected to the system. + + On Linux, it scans '/dev/video*' paths. On other systems (like macOS, Windows), + it checks indices from 0 up to `MAX_OPENCV_INDEX`. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, + where each dictionary contains 'type', 'id' (port index or path), + and the default profile properties (width, height, fps, format). + """ + found_cameras_info = [] + + if platform.system() == "Linux": + possible_paths = sorted(Path("/dev").glob("video*"), key=lambda p: p.name) + targets_to_scan = [str(p) for p in possible_paths] + else: + targets_to_scan = list(range(MAX_OPENCV_INDEX)) + + for target in targets_to_scan: + camera = cv2.VideoCapture(target) + if camera.isOpened(): + default_width = int(camera.get(cv2.CAP_PROP_FRAME_WIDTH)) + default_height = int(camera.get(cv2.CAP_PROP_FRAME_HEIGHT)) + default_fps = camera.get(cv2.CAP_PROP_FPS) + default_format = camera.get(cv2.CAP_PROP_FORMAT) + camera_info = { + "name": f"OpenCV Camera @ {target}", + "type": "OpenCV", + "id": target, + "backend_api": camera.getBackendName(), + "default_stream_profile": { + "format": default_format, + "width": default_width, + "height": default_height, + "fps": default_fps, + }, + } + + found_cameras_info.append(camera_info) + camera.release() + + return found_cameras_info + + def read(self, color_mode: ColorMode | None = None) -> np.ndarray: + """ + Reads a single frame synchronously from the camera. + + This is a blocking call. It waits for the next available frame from the + camera hardware via OpenCV. + + Args: + color_mode (Optional[ColorMode]): If specified, overrides the default + color mode (`self.color_mode`) for this read operation (e.g., + request RGB even if default is BGR). + + Returns: + np.ndarray: The captured frame as a NumPy array in the format + (height, width, channels), using the specified or default + color mode and applying any configured rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading the frame from the camera fails or if the + received frame dimensions don't match expectations before rotation. + ValueError: If an invalid `color_mode` is requested. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start_time = time.perf_counter() + + ret, frame = self.videocapture.read() + + if not ret or frame is None: + raise RuntimeError(f"{self} read failed (status={ret}).") + + processed_frame = self._postprocess_image(frame, color_mode) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return processed_frame + + def _postprocess_image(self, image: np.ndarray, color_mode: ColorMode | None = None) -> np.ndarray: + """ + Applies color conversion, dimension validation, and rotation to a raw frame. + + Args: + image (np.ndarray): The raw image frame (expected BGR format from OpenCV). + color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, + uses the instance's default `self.color_mode`. + + Returns: + np.ndarray: The processed image frame. + + Raises: + ValueError: If the requested `color_mode` is invalid. + RuntimeError: If the raw frame dimensions do not match the configured + `width` and `height`. + """ + requested_color_mode = self.color_mode if color_mode is None else color_mode + + if requested_color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid color mode '{requested_color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + ) + + h, w, c = image.shape + + if h != self.capture_height or w != self.capture_width: + raise RuntimeError( + f"{self} frame width={w} or height={h} do not match configured width={self.capture_width} or height={self.capture_height}." + ) + + if c != 3: + raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).") + + processed_image = image + if requested_color_mode == ColorMode.RGB: + processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + processed_image = cv2.rotate(processed_image, self.rotation) + + return processed_image + + def _read_loop(self): + """ + Internal loop run by the background thread for asynchronous reading. + + On each iteration: + 1. Reads a color frame + 2. Stores result in latest_frame (thread-safe) + 3. Sets new_frame_event to notify listeners + + Stops on DeviceNotConnectedError, logs other errors and continues. + """ + while not self.stop_event.is_set(): + try: + color_image = self.read() + + with self.frame_lock: + self.latest_frame = color_image + self.new_frame_event.set() + + except DeviceNotConnectedError: + break + except Exception as e: + logger.warning(f"Error reading frame in background thread for {self}: {e}") + + def _start_read_thread(self) -> None: + """Starts or restarts the background read thread if it's not running.""" + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=0.1) + if self.stop_event is not None: + self.stop_event.set() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") + self.thread.daemon = True + self.thread.start() + + def _stop_read_thread(self) -> None: + """Signals the background read thread to stop and waits for it to join.""" + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + def async_read(self, timeout_ms: float = 200) -> np.ndarray: + """ + Reads the latest available frame asynchronously. + + This method retrieves the most recent frame captured by the background + read thread. It does not block waiting for the camera hardware directly, + but may wait up to timeout_ms for the background thread to provide a frame. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms (0.2 seconds). + + Returns: + np.ndarray: The latest captured frame as a NumPy array in the format + (height, width, channels), processed according to configuration. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame becomes available within the specified timeout. + RuntimeError: If an unexpected error occurs. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + self._start_read_thread() + + if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): + thread_alive = self.thread is not None and self.thread.is_alive() + raise TimeoutError( + f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " + f"Read thread alive: {thread_alive}." + ) + + with self.frame_lock: + frame = self.latest_frame + self.new_frame_event.clear() + + if frame is None: + raise RuntimeError(f"Internal error: Event set but no frame available for {self}.") + + return frame + + def disconnect(self): + """ + Disconnects from the camera and cleans up resources. + + Stops the background read thread (if running) and releases the OpenCV + VideoCapture object. + + Raises: + DeviceNotConnectedError: If the camera is already disconnected. + """ + if not self.is_connected and self.thread is None: + raise DeviceNotConnectedError(f"{self} not connected.") + + if self.thread is not None: + self._stop_read_thread() + + if self.videocapture is not None: + self.videocapture.release() + self.videocapture = None + + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/cameras/opencv/configuration_opencv.py b/lerobot/common/cameras/opencv/configuration_opencv.py new file mode 100644 index 000000000..3ac92de36 --- /dev/null +++ b/lerobot/common/cameras/opencv/configuration_opencv.py @@ -0,0 +1,73 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from pathlib import Path + +from ..configs import CameraConfig, ColorMode, Cv2Rotation + + +@CameraConfig.register_subclass("opencv") +@dataclass +class OpenCVCameraConfig(CameraConfig): + """Configuration class for OpenCV-based camera devices or video files. + + This class provides configuration options for cameras accessed through OpenCV, + supporting both physical camera devices and video files. It includes settings + for resolution, frame rate, color mode, and image rotation. + + Example configurations: + ```python + # Basic configurations + OpenCVCameraConfig(0, 30, 1280, 720) # 1280x720 @ 30FPS + OpenCVCameraConfig(/dev/video4, 60, 640, 480) # 640x480 @ 60FPS + + # Advanced configurations + OpenCVCameraConfig(128422271347, 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation + ``` + + Attributes: + index_or_path: Either an integer representing the camera device index, + or a Path object pointing to a video file. + fps: Requested frames per second for the color stream. + width: Requested frame width in pixels for the color stream. + height: Requested frame height in pixels for the color stream. + color_mode: Color mode for image output (RGB or BGR). Defaults to RGB. + rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation. + warmup_s: Time reading frames before returning from connect (in seconds) + + Note: + - Only 3-channel color output (RGB/BGR) is currently supported. + """ + + index_or_path: int | Path + color_mode: ColorMode = ColorMode.RGB + rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION + warmup_s: int = 1 + + def __post_init__(self): + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." + ) + + if self.rotation not in ( + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ): + raise ValueError( + f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." + ) diff --git a/lerobot/common/cameras/realsense/__init__.py b/lerobot/common/cameras/realsense/__init__.py new file mode 100644 index 000000000..67f2f4000 --- /dev/null +++ b/lerobot/common/cameras/realsense/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .camera_realsense import RealSenseCamera +from .configuration_realsense import RealSenseCameraConfig diff --git a/lerobot/common/cameras/realsense/camera_realsense.py b/lerobot/common/cameras/realsense/camera_realsense.py new file mode 100644 index 000000000..2bcbee75c --- /dev/null +++ b/lerobot/common/cameras/realsense/camera_realsense.py @@ -0,0 +1,556 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Provides the RealSenseCamera class for capturing frames from Intel RealSense cameras. +""" + +import logging +import time +from threading import Event, Lock, Thread +from typing import Any, Dict, List + +import cv2 +import numpy as np + +try: + import pyrealsense2 as rs +except Exception as e: + logging.info(f"Could not import realsense: {e}") + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..camera import Camera +from ..configs import ColorMode +from ..utils import get_cv2_rotation +from .configuration_realsense import RealSenseCameraConfig + +logger = logging.getLogger(__name__) + + +class RealSenseCamera(Camera): + """ + Manages interactions with Intel RealSense cameras for frame and depth recording. + + This class provides an interface similar to `OpenCVCamera` but tailored for + RealSense devices, leveraging the `pyrealsense2` library. It uses the camera's + unique serial number for identification, offering more stability than device + indices, especially on Linux. It also supports capturing depth maps alongside + color frames. + + Use the provided utility script to find available camera indices and default profiles: + ```bash + python -m lerobot.find_cameras realsense + ``` + + A `RealSenseCamera` instance requires a configuration object specifying the + camera's serial number or a unique device name. If using the name, ensure only + one camera with that name is connected. + + The camera's default settings (FPS, resolution, color mode) from the stream + profile are used unless overridden in the configuration. + + Example: + ```python + from lerobot.common.cameras.realsense import RealSenseCamera, RealSenseCameraConfig + from lerobot.common.cameras import ColorMode, Cv2Rotation + + # Basic usage with serial number + config = RealSenseCameraConfig(serial_number_or_name="0123456789") # Replace with actual SN + camera = RealSenseCamera(config) + camera.connect() + + # Read 1 frame synchronously + color_image = camera.read() + print(color_image.shape) + + # Read 1 frame asynchronously + async_image = camera.async_read() + + # When done, properly disconnect the camera using + camera.disconnect() + + # Example with depth capture and custom settings + custom_config = RealSenseCameraConfig( + serial_number_or_name="0123456789", # Replace with actual SN + fps=30, + width=1280, + height=720, + color_mode=ColorMode.BGR, # Request BGR output + rotation=Cv2Rotation.NO_ROTATION, + use_depth=True + ) + depth_camera = RealSenseCamera(custom_config) + depth_camera.connect() + + # Read 1 depth frame + depth_map = depth_camera.read_depth() + + # Example using a unique camera name + name_config = RealSenseCameraConfig(serial_number_or_name="Intel RealSense D435") # If unique + name_camera = RealSenseCamera(name_config) + # ... connect, read, disconnect ... + ``` + """ + + def __init__(self, config: RealSenseCameraConfig): + """ + Initializes the RealSenseCamera instance. + + Args: + config: The configuration settings for the camera. + """ + + super().__init__(config) + + self.config = config + + if config.serial_number_or_name.isdigit(): + self.serial_number = config.serial_number_or_name + else: + self.serial_number = self._find_serial_number_from_name(config.serial_number_or_name) + + self.fps = config.fps + self.color_mode = config.color_mode + self.use_depth = config.use_depth + self.warmup_s = config.warmup_s + + self.rs_pipeline: rs.pipeline | None = None + self.rs_profile: rs.pipeline_profile | None = None + + self.thread: Thread | None = None + self.stop_event: Event | None = None + self.frame_lock: Lock = Lock() + self.latest_frame: np.ndarray | None = None + self.new_frame_event: Event = Event() + + self.rotation: int | None = get_cv2_rotation(config.rotation) + + if self.height and self.width: + self.capture_width, self.capture_height = self.width, self.height + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.capture_width, self.capture_height = self.height, self.width + + def __str__(self) -> str: + return f"{self.__class__.__name__}({self.serial_number})" + + @property + def is_connected(self) -> bool: + """Checks if the camera pipeline is started and streams are active.""" + return self.rs_pipeline is not None and self.rs_profile is not None + + def connect(self, warmup: bool = True): + """ + Connects to the RealSense camera specified in the configuration. + + Initializes the RealSense pipeline, configures the required streams (color + and optionally depth), starts the pipeline, and validates the actual stream settings. + + Raises: + DeviceAlreadyConnectedError: If the camera is already connected. + ValueError: If the configuration is invalid (e.g., missing serial/name, name not unique). + ConnectionError: If the camera is found but fails to start the pipeline or no RealSense devices are detected at all. + RuntimeError: If the pipeline starts but fails to apply requested settings. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} is already connected.") + + self.rs_pipeline = rs.pipeline() + rs_config = rs.config() + self._configure_rs_pipeline_config(rs_config) + + try: + self.rs_profile = self.rs_pipeline.start(rs_config) + except RuntimeError as e: + self.rs_profile = None + self.rs_pipeline = None + raise ConnectionError( + f"Failed to open {self}." + "Run `python -m lerobot.find_cameras realsense` to find available cameras." + ) from e + + self._configure_capture_settings() + + if warmup: + time.sleep( + 1 + ) # NOTE(Steven): RS cameras need a bit of time to warm up before the first read. If we don't wait, the first read from the warmup will raise. + start_time = time.time() + while time.time() - start_time < self.warmup_s: + self.read() + time.sleep(0.1) + + logger.info(f"{self} connected.") + + @staticmethod + def find_cameras() -> List[Dict[str, Any]]: + """ + Detects available Intel RealSense cameras connected to the system. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, + where each dictionary contains 'type', 'id' (serial number), 'name', + firmware version, USB type, and other available specs, and the default profile properties (width, height, fps, format). + + Raises: + OSError: If pyrealsense2 is not installed. + ImportError: If pyrealsense2 is not installed. + """ + found_cameras_info = [] + context = rs.context() + devices = context.query_devices() + + for device in devices: + camera_info = { + "name": device.get_info(rs.camera_info.name), + "type": "RealSense", + "id": device.get_info(rs.camera_info.serial_number), + "firmware_version": device.get_info(rs.camera_info.firmware_version), + "usb_type_descriptor": device.get_info(rs.camera_info.usb_type_descriptor), + "physical_port": device.get_info(rs.camera_info.physical_port), + "product_id": device.get_info(rs.camera_info.product_id), + "product_line": device.get_info(rs.camera_info.product_line), + } + + # Get stream profiles for each sensor + sensors = device.query_sensors() + for sensor in sensors: + profiles = sensor.get_stream_profiles() + + for profile in profiles: + if profile.is_video_stream_profile() and profile.is_default(): + vprofile = profile.as_video_stream_profile() + stream_info = { + "stream_type": vprofile.stream_name(), + "format": vprofile.format().name, + "width": vprofile.width(), + "height": vprofile.height(), + "fps": vprofile.fps(), + } + camera_info["default_stream_profile"] = stream_info + + found_cameras_info.append(camera_info) + + return found_cameras_info + + def _find_serial_number_from_name(self, name: str) -> str: + """Finds the serial number for a given unique camera name.""" + camera_infos = self.find_cameras() + found_devices = [cam for cam in camera_infos if str(cam["name"]) == name] + + if not found_devices: + available_names = [cam["name"] for cam in camera_infos] + raise ValueError( + f"No RealSense camera found with name '{name}'. Available camera names: {available_names}" + ) + + if len(found_devices) > 1: + serial_numbers = [dev["serial_number"] for dev in found_devices] + raise ValueError( + f"Multiple RealSense cameras found with name '{name}'. " + f"Please use a unique serial number instead. Found SNs: {serial_numbers}" + ) + + serial_number = str(found_devices[0]["serial_number"]) + return serial_number + + def _configure_rs_pipeline_config(self, rs_config): + """Creates and configures the RealSense pipeline configuration object.""" + rs.config.enable_device(rs_config, self.serial_number) + + if self.width and self.height and self.fps: + rs_config.enable_stream( + rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps + ) + if self.use_depth: + rs_config.enable_stream( + rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps + ) + else: + rs_config.enable_stream(rs.stream.color) + if self.use_depth: + rs_config.enable_stream(rs.stream.depth) + + def _configure_capture_settings(self) -> None: + """Sets fps, width, and height from device stream if not already configured. + + Uses the color stream profile to update unset attributes. Handles rotation by + swapping width/height when needed. Original capture dimensions are always stored. + + Raises: + DeviceNotConnectedError: If device is not connected. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"Cannot validate settings for {self} as it is not connected.") + + stream = self.rs_profile.get_stream(rs.stream.color).as_video_stream_profile() + + if self.fps is None: + self.fps = stream.fps() + + if self.width is None or self.height is None: + actual_width = int(round(stream.width())) + actual_height = int(round(stream.height())) + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + self.width, self.height = actual_height, actual_width + self.capture_width, self.capture_height = actual_width, actual_height + else: + self.width, self.height = actual_width, actual_height + self.capture_width, self.capture_height = actual_width, actual_height + + def read_depth(self, timeout_ms: int = 200) -> np.ndarray: + """ + Reads a single frame (depth) synchronously from the camera. + + This is a blocking call. It waits for a coherent set of frames (depth) + from the camera hardware via the RealSense pipeline. + + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. + + Returns: + np.ndarray: The depth map as a NumPy array (height, width) + of type `np.uint16` (raw depth values in millimeters) and rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the pipeline fails or frames are invalid. + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + if not self.use_depth: + raise RuntimeError( + f"Failed to capture depth frame '.read_depth()'. Depth stream is not enabled for {self}." + ) + + start_time = time.perf_counter() + + ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) + + if not ret or frame is None: + raise RuntimeError(f"{self} read_depth failed (status={ret}).") + + depth_frame = frame.get_depth_frame() + depth_map = np.asanyarray(depth_frame.get_data()) + + depth_map_processed = self._postprocess_image(depth_map, depth_frame=True) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return depth_map_processed + + def read(self, color_mode: ColorMode | None = None, timeout_ms: int = 200) -> np.ndarray: + """ + Reads a single frame (color) synchronously from the camera. + + This is a blocking call. It waits for a coherent set of frames (color) + from the camera hardware via the RealSense pipeline. + + Args: + timeout_ms (int): Maximum time in milliseconds to wait for a frame. Defaults to 200ms. + + Returns: + np.ndarray: The captured color frame as a NumPy array + (height, width, channels), processed according to `color_mode` and rotation. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + RuntimeError: If reading frames from the pipeline fails or frames are invalid. + ValueError: If an invalid `color_mode` is requested. + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start_time = time.perf_counter() + + ret, frame = self.rs_pipeline.try_wait_for_frames(timeout_ms=timeout_ms) + + if not ret or frame is None: + raise RuntimeError(f"{self} read failed (status={ret}).") + + color_frame = frame.get_color_frame() + color_image_raw = np.asanyarray(color_frame.get_data()) + + color_image_processed = self._postprocess_image(color_image_raw, color_mode) + + read_duration_ms = (time.perf_counter() - start_time) * 1e3 + logger.debug(f"{self} read took: {read_duration_ms:.1f}ms") + + return color_image_processed + + def _postprocess_image( + self, image: np.ndarray, color_mode: ColorMode | None = None, depth_frame: bool = False + ) -> np.ndarray: + """ + Applies color conversion, dimension validation, and rotation to a raw color frame. + + Args: + image (np.ndarray): The raw image frame (expected RGB format from RealSense). + color_mode (Optional[ColorMode]): The target color mode (RGB or BGR). If None, + uses the instance's default `self.color_mode`. + + Returns: + np.ndarray: The processed image frame according to `self.color_mode` and `self.rotation`. + + Raises: + ValueError: If the requested `color_mode` is invalid. + RuntimeError: If the raw frame dimensions do not match the configured + `width` and `height`. + """ + + if color_mode and color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"Invalid requested color mode '{color_mode}'. Expected {ColorMode.RGB} or {ColorMode.BGR}." + ) + + if depth_frame: + h, w = image.shape + else: + h, w, c = image.shape + + if c != 3: + raise RuntimeError(f"{self} frame channels={c} do not match expected 3 channels (RGB/BGR).") + + if h != self.capture_height or w != self.capture_width: + raise RuntimeError( + f"{self} frame width={w} or height={h} do not match configured width={self.capture_width} or height={self.capture_height}." + ) + + processed_image = image + if self.color_mode == ColorMode.BGR: + processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + processed_image = cv2.rotate(processed_image, self.rotation) + + return processed_image + + def _read_loop(self): + """ + Internal loop run by the background thread for asynchronous reading. + + On each iteration: + 1. Reads a color frame with 500ms timeout + 2. Stores result in latest_frame (thread-safe) + 3. Sets new_frame_event to notify listeners + + Stops on DeviceNotConnectedError, logs other errors and continues. + """ + while not self.stop_event.is_set(): + try: + color_image = self.read(timeout_ms=500) + + with self.frame_lock: + self.latest_frame = color_image + self.new_frame_event.set() + + except DeviceNotConnectedError: + break + except Exception as e: + logger.warning(f"Error reading frame in background thread for {self}: {e}") + + def _start_read_thread(self) -> None: + """Starts or restarts the background read thread if it's not running.""" + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=0.1) + if self.stop_event is not None: + self.stop_event.set() + + self.stop_event = Event() + self.thread = Thread(target=self._read_loop, args=(), name=f"{self}_read_loop") + self.thread.daemon = True + self.thread.start() + + def _stop_read_thread(self): + """Signals the background read thread to stop and waits for it to join.""" + if self.stop_event is not None: + self.stop_event.set() + + if self.thread is not None and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + self.thread = None + self.stop_event = None + + # NOTE(Steven): Missing implementation for depth for now + def async_read(self, timeout_ms: float = 200) -> np.ndarray: + """ + Reads the latest available frame data (color) asynchronously. + + This method retrieves the most recent color frame captured by the background + read thread. It does not block waiting for the camera hardware directly, + but may wait up to timeout_ms for the background thread to provide a frame. + + Args: + timeout_ms (float): Maximum time in milliseconds to wait for a frame + to become available. Defaults to 200ms (0.2 seconds). + + Returns: + np.ndarray: + The latest captured frame data (color image), processed according to configuration. + + Raises: + DeviceNotConnectedError: If the camera is not connected. + TimeoutError: If no frame data becomes available within the specified timeout. + RuntimeError: If the background thread died unexpectedly or another error occurs. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.thread is None or not self.thread.is_alive(): + self._start_read_thread() + + if not self.new_frame_event.wait(timeout=timeout_ms / 1000.0): + thread_alive = self.thread is not None and self.thread.is_alive() + raise TimeoutError( + f"Timed out waiting for frame from camera {self} after {timeout_ms} ms. " + f"Read thread alive: {thread_alive}." + ) + + with self.frame_lock: + frame = self.latest_frame + self.new_frame_event.clear() + + if frame is None: + raise RuntimeError(f"Internal error: Event set but no frame available for {self}.") + + return frame + + def disconnect(self): + """ + Disconnects from the camera, stops the pipeline, and cleans up resources. + + Stops the background read thread (if running) and stops the RealSense pipeline. + + Raises: + DeviceNotConnectedError: If the camera is already disconnected (pipeline not running). + """ + + if not self.is_connected and self.thread is None: + raise DeviceNotConnectedError( + f"Attempted to disconnect {self}, but it appears already disconnected." + ) + + if self.thread is not None: + self._stop_read_thread() + + if self.rs_pipeline is not None: + self.rs_pipeline.stop() + self.rs_pipeline = None + self.rs_profile = None + + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/cameras/realsense/configuration_realsense.py b/lerobot/common/cameras/realsense/configuration_realsense.py new file mode 100644 index 000000000..82e7c0d36 --- /dev/null +++ b/lerobot/common/cameras/realsense/configuration_realsense.py @@ -0,0 +1,82 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..configs import CameraConfig, ColorMode, Cv2Rotation + + +@CameraConfig.register_subclass("intelrealsense") +@dataclass +class RealSenseCameraConfig(CameraConfig): + """Configuration class for Intel RealSense cameras. + + This class provides specialized configuration options for Intel RealSense cameras, + including support for depth sensing and device identification via serial number or name. + + Example configurations for Intel RealSense D405: + ```python + # Basic configurations + RealSenseCameraConfig("0123456789", 30, 1280, 720) # 1280x720 @ 30FPS + RealSenseCameraConfig("0123456789", 60, 640, 480) # 640x480 @ 60FPS + + # Advanced configurations + RealSenseCameraConfig("0123456789", 30, 640, 480, use_depth=True) # With depth sensing + RealSenseCameraConfig("0123456789", 30, 640, 480, rotation=Cv2Rotation.ROTATE_90) # With 90° rotation + ``` + + Attributes: + fps: Requested frames per second for the color stream. + width: Requested frame width in pixels for the color stream. + height: Requested frame height in pixels for the color stream. + serial_number_or_name: Unique serial number or human-readable name to identify the camera. + color_mode: Color mode for image output (RGB or BGR). Defaults to RGB. + use_depth: Whether to enable depth stream. Defaults to False. + rotation: Image rotation setting (0°, 90°, 180°, or 270°). Defaults to no rotation. + warmup_s: Time reading frames before returning from connect (in seconds) + + Note: + - Either name or serial_number must be specified. + - Depth stream configuration (if enabled) will use the same FPS as the color stream. + - The actual resolution and FPS may be adjusted by the camera to the nearest supported mode. + - For `fps`, `width` and `height`, either all of them need to be set, or none of them. + """ + + serial_number_or_name: str + color_mode: ColorMode = ColorMode.RGB + use_depth: bool = False + rotation: Cv2Rotation = Cv2Rotation.NO_ROTATION + warmup_s: int = 1 + + def __post_init__(self): + if self.color_mode not in (ColorMode.RGB, ColorMode.BGR): + raise ValueError( + f"`color_mode` is expected to be {ColorMode.RGB.value} or {ColorMode.BGR.value}, but {self.color_mode} is provided." + ) + + if self.rotation not in ( + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ): + raise ValueError( + f"`rotation` is expected to be in {(Cv2Rotation.NO_ROTATION, Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_180, Cv2Rotation.ROTATE_270)}, but {self.rotation} is provided." + ) + + values = (self.fps, self.width, self.height) + if any(v is not None for v in values) and any(v is None for v in values): + raise ValueError( + "For `fps`, `width` and `height`, either all of them need to be set, or none of them." + ) diff --git a/lerobot/common/cameras/utils.py b/lerobot/common/cameras/utils.py new file mode 100644 index 000000000..f8bbd6e70 --- /dev/null +++ b/lerobot/common/cameras/utils.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +from pathlib import Path +from typing import TypeAlias + +from .camera import Camera +from .configs import CameraConfig, Cv2Rotation + +IndexOrPath: TypeAlias = int | Path + + +def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]: + cameras = {} + + for key, cfg in camera_configs.items(): + if cfg.type == "opencv": + from .opencv import OpenCVCamera + + cameras[key] = OpenCVCamera(cfg) + + elif cfg.type == "intelrealsense": + from .realsense.camera_realsense import RealSenseCamera + + cameras[key] = RealSenseCamera(cfg) + else: + raise ValueError(f"The motor type '{cfg.type}' is not valid.") + + return cameras + + +def get_cv2_rotation(rotation: Cv2Rotation) -> int | None: + import cv2 + + if rotation == Cv2Rotation.ROTATE_90: + return cv2.ROTATE_90_CLOCKWISE + elif rotation == Cv2Rotation.ROTATE_180: + return cv2.ROTATE_180 + elif rotation == Cv2Rotation.ROTATE_270: + return cv2.ROTATE_90_COUNTERCLOCKWISE + else: + return None + + +def get_cv2_backend() -> int: + import cv2 + + if platform.system() == "Windows": + return cv2.CAP_AVFOUNDATION + else: + return cv2.CAP_ANY diff --git a/lerobot/common/constants.py b/lerobot/common/constants.py index 973595cdf..30777239e 100644 --- a/lerobot/common/constants.py +++ b/lerobot/common/constants.py @@ -17,11 +17,16 @@ from pathlib import Path from huggingface_hub.constants import HF_HOME -OBS_ENV = "observation.environment_state" -OBS_ROBOT = "observation.state" +OBS_ENV_STATE = "observation.environment_state" +OBS_STATE = "observation.state" OBS_IMAGE = "observation.image" OBS_IMAGES = "observation.images" ACTION = "action" +REWARD = "next.reward" + +ROBOTS = "robots" +ROBOT_TYPE = "robot_type" +TELEOPERATORS = "teleoperators" # files & directories CHECKPOINTS_DIR = "checkpoints" @@ -34,12 +39,16 @@ OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" SCHEDULER_STATE = "scheduler_state.json" -# cache dir -default_cache_path = Path(HF_HOME) / "lerobot" -HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() - if "LEROBOT_HOME" in os.environ: raise ValueError( f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n" "'LEROBOT_HOME' is deprecated, please use 'HF_LEROBOT_HOME' instead." ) + +# cache dir +default_cache_path = Path(HF_HOME) / "lerobot" +HF_LEROBOT_HOME = Path(os.getenv("HF_LEROBOT_HOME", default_cache_path)).expanduser() + +# calibration dir +default_calibration_path = HF_LEROBOT_HOME / "calibration" +HF_LEROBOT_CALIBRATION = Path(os.getenv("HF_LEROBOT_CALIBRATION", default_calibration_path)).expanduser() diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index 38c01b42f..88d3f767f 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -49,7 +49,7 @@ def resolve_delta_timestamps( "observation.state": [-0.04, -0.02, 0] "observation.action": [-0.02, 0, 0.02] } - returns `None` if the the resulting dict is empty. + returns `None` if the resulting dict is empty. """ delta_timestamps = {} for key in ds_meta.features: diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 6fc0ee2f8..4a4e1ab05 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -106,7 +106,7 @@ def worker_process(queue: queue.Queue, num_threads: int): class AsyncImageWriter: """ This class abstract away the initialisation of processes or/and threads to - save images on disk asynchrounously, which is critical to control a robot and record data + save images on disk asynchronously, which is critical to control a robot and record data at a high frame rate. When `num_processes=0`, it creates a threads pool of size `num_threads`. diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 720c939b8..3ac1d5771 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -128,7 +128,7 @@ class SharpnessJitter(Transform): raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.") if not 0.0 <= sharpness[0] <= sharpness[1]: - raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.") + raise ValueError(f"sharpness values should be between (0., inf), but got {sharpness}.") return float(sharpness[0]), float(sharpness[1]) diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py index 48c3daccf..8a631168e 100644 --- a/lerobot/common/datasets/utils.py +++ b/lerobot/common/datasets/utils.py @@ -45,7 +45,7 @@ from lerobot.common.datasets.backward_compatibility import ( BackwardCompatibilityError, ForwardCompatibilityError, ) -from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.robots import Robot from lerobot.common.utils.utils import is_valid_numpy_dtype_string from lerobot.configs.types import FeatureType, PolicyFeature @@ -468,6 +468,59 @@ def get_hf_features_from_features(features: dict) -> datasets.Features: return datasets.Features(hf_features) +def _validate_feature_names(features: dict[str, dict]) -> None: + invalid_features = {name: ft for name, ft in features.items() if "/" in name} + if invalid_features: + raise ValueError(f"Feature names should not contain '/'. Found '/' in '{invalid_features}'.") + + +def hw_to_dataset_features( + hw_features: dict[str, type | tuple], prefix: str, use_video: bool = True +) -> dict[str, dict]: + features = {} + joint_fts = {key: ftype for key, ftype in hw_features.items() if ftype is float} + cam_fts = {key: shape for key, shape in hw_features.items() if isinstance(shape, tuple)} + + if joint_fts and prefix == "action": + features[prefix] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + if joint_fts and prefix == "observation": + features[f"{prefix}.state"] = { + "dtype": "float32", + "shape": (len(joint_fts),), + "names": list(joint_fts), + } + + for key, shape in cam_fts.items(): + features[f"{prefix}.images.{key}"] = { + "dtype": "video" if use_video else "image", + "shape": shape, + "names": ["height", "width", "channels"], + } + + _validate_feature_names(features) + return features + + +def build_dataset_frame( + ds_features: dict[str, dict], values: dict[str, Any], prefix: str +) -> dict[str, np.ndarray]: + frame = {} + for key, ft in ds_features.items(): + if key in DEFAULT_FEATURES or not key.startswith(prefix): + continue + elif ft["dtype"] == "float32" and len(ft["shape"]) == 1: + frame[key] = np.array([values[name] for name in ft["names"]], dtype=np.float32) + elif ft["dtype"] in ["image", "video"]: + frame[key] = values[key.removeprefix(f"{prefix}.images.")] + + return frame + + def get_features_from_robot(robot: Robot, use_videos: bool = True) -> dict: # TODO(rcadene): add fps for each feature camera_ft = {} @@ -497,7 +550,7 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea type = FeatureType.ENV elif key.startswith("observation"): type = FeatureType.STATE - elif key == "action": + elif key.startswith("action"): type = FeatureType.ACTION else: continue @@ -513,9 +566,9 @@ def dataset_to_policy_features(features: dict[str, dict]) -> dict[str, PolicyFea def create_empty_dataset_info( codebase_version: str, fps: int, - robot_type: str, features: dict, use_videos: bool, + robot_type: str | None = None, ) -> dict: return { "codebase_version": codebase_version, @@ -767,16 +820,12 @@ class IterableNamespace(SimpleNamespace): def validate_frame(frame: dict, features: dict): - optional_features = {"timestamp"} - expected_features = (set(features) - set(DEFAULT_FEATURES.keys())) | {"task"} - actual_features = set(frame.keys()) + expected_features = set(features) - set(DEFAULT_FEATURES) + actual_features = set(frame) - error_message = validate_features_presence(actual_features, expected_features, optional_features) + error_message = validate_features_presence(actual_features, expected_features) - if "task" in frame: - error_message += validate_feature_string("task", frame["task"]) - - common_features = actual_features & (expected_features | optional_features) + common_features = actual_features & expected_features for name in common_features - {"task"}: error_message += validate_feature_dtype_and_shape(name, features[name], frame[name]) @@ -784,12 +833,10 @@ def validate_frame(frame: dict, features: dict): raise ValueError(error_message) -def validate_features_presence( - actual_features: set[str], expected_features: set[str], optional_features: set[str] -): +def validate_features_presence(actual_features: set[str], expected_features: set[str]): error_message = "" missing_features = expected_features - actual_features - extra_features = actual_features - (expected_features | optional_features) + extra_features = actual_features - expected_features if missing_features or extra_features: error_message += "Feature mismatch in `frame` dictionary:\n" diff --git a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py index 99ab2cbf6..9b21cf7ca 100644 --- a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py @@ -27,7 +27,7 @@ from textwrap import dedent from lerobot import available_datasets from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset -from lerobot.common.robot_devices.robots.configs import AlohaRobotConfig +from lerobot.common.robots.aloha.configuration_aloha import AlohaRobotConfig LOCAL_DIR = Path("data/") @@ -36,7 +36,7 @@ ALOHA_MOBILE_INFO = { "robot_config": AlohaRobotConfig(), "license": "mit", "url": "https://mobile-aloha.github.io/", - "paper": "https://arxiv.org/abs/2401.02117", + "paper": "https://huggingface.co/papers/2401.02117", "citation_bibtex": dedent(r""" @inproceedings{fu2024mobile, author = {Fu, Zipeng and Zhao, Tony Z. and Finn, Chelsea}, @@ -49,7 +49,7 @@ ALOHA_STATIC_INFO = { "robot_config": AlohaRobotConfig(), "license": "mit", "url": "https://tonyzhaozh.github.io/aloha/", - "paper": "https://arxiv.org/abs/2304.13705", + "paper": "https://huggingface.co/papers/2304.13705", "citation_bibtex": dedent(r""" @article{Zhao2023LearningFB, title={Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware}, @@ -57,13 +57,13 @@ ALOHA_STATIC_INFO = { journal={RSS}, year={2023}, volume={abs/2304.13705}, - url={https://arxiv.org/abs/2304.13705} + url={https://huggingface.co/papers/2304.13705} }""").lstrip(), } PUSHT_INFO = { "license": "mit", "url": "https://diffusion-policy.cs.columbia.edu/", - "paper": "https://arxiv.org/abs/2303.04137v5", + "paper": "https://huggingface.co/papers/2303.04137", "citation_bibtex": dedent(r""" @article{chi2024diffusionpolicy, author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, @@ -75,7 +75,7 @@ PUSHT_INFO = { XARM_INFO = { "license": "mit", "url": "https://www.nicklashansen.com/td-mpc/", - "paper": "https://arxiv.org/abs/2203.04955", + "paper": "https://huggingface.co/papers/2203.04955", "citation_bibtex": dedent(r""" @inproceedings{Hansen2022tdmpc, title={Temporal Difference Learning for Model Predictive Control}, @@ -244,7 +244,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/BUDS-website/", - "paper": "https://arxiv.org/abs/2109.13841", + "paper": "https://huggingface.co/papers/2109.13841", "citation_bibtex": dedent(r""" @article{zhu2022bottom, title={Bottom-Up Skill Discovery From Unsegmented Demonstrations for Long-Horizon Robot Manipulation}, @@ -261,7 +261,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/sailor/", - "paper": "https://arxiv.org/abs/2210.11435", + "paper": "https://huggingface.co/papers/2210.11435", "citation_bibtex": dedent(r""" @inproceedings{nasiriany2022sailor, title={Learning and Retrieval from Prior Data for Skill-based Imitation Learning}, @@ -274,7 +274,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/sirius/", - "paper": "https://arxiv.org/abs/2211.08416", + "paper": "https://huggingface.co/papers/2211.08416", "citation_bibtex": dedent(r""" @inproceedings{liu2022robot, title = {Robot Learning on the Job: Human-in-the-Loop Autonomy and Learning During Deployment}, @@ -298,14 +298,14 @@ DATASETS = { "tasks_col": "language_instruction", "license": "cc-by-4.0", "url": "https://sites.google.com/view/cablerouting/home", - "paper": "https://arxiv.org/abs/2307.08927", + "paper": "https://huggingface.co/papers/2307.08927", "citation_bibtex": dedent(r""" @article{luo2023multistage, author = {Jianlan Luo and Charles Xu and Xinyang Geng and Gilbert Feng and Kuan Fang and Liam Tan and Stefan Schaal and Sergey Levine}, title = {Multi-Stage Cable Routing through Hierarchical Imitation Learning}, journal = {arXiv pre-print}, year = {2023}, - url = {https://arxiv.org/abs/2307.08927}, + url = {https://huggingface.co/papers/2307.08927}, }""").lstrip(), }, "berkeley_fanuc_manipulation": { @@ -322,7 +322,7 @@ DATASETS = { "berkeley_gnm_cory_hall": { "tasks_col": "language_instruction", "license": "mit", - "paper": "https://arxiv.org/abs/1709.10489", + "paper": "https://huggingface.co/papers/1709.10489", "citation_bibtex": dedent(r""" @inproceedings{kahn2018self, title={Self-supervised deep reinforcement learning with generalized computation graphs for robot navigation}, @@ -337,7 +337,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/recon-robot", - "paper": "https://arxiv.org/abs/2104.05859", + "paper": "https://huggingface.co/papers/2104.05859", "citation_bibtex": dedent(r""" @inproceedings{shah2021rapid, title={Rapid Exploration for Open-World Navigation with Latent Goal Models}, @@ -351,7 +351,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/SACSoN-review", - "paper": "https://arxiv.org/abs/2306.01874", + "paper": "https://huggingface.co/papers/2306.01874", "citation_bibtex": dedent(r""" @article{hirose2023sacson, title={SACSoN: Scalable Autonomous Data Collection for Social Navigation}, @@ -363,7 +363,7 @@ DATASETS = { "berkeley_mvp": { "tasks_col": "language_instruction", "license": "mit", - "paper": "https://arxiv.org/abs/2203.06173", + "paper": "https://huggingface.co/papers/2203.06173", "citation_bibtex": dedent(r""" @InProceedings{Radosavovic2022, title = {Real-World Robot Learning with Masked Visual Pre-training}, @@ -375,7 +375,7 @@ DATASETS = { "berkeley_rpt": { "tasks_col": "language_instruction", "license": "mit", - "paper": "https://arxiv.org/abs/2306.10007", + "paper": "https://huggingface.co/papers/2306.10007", "citation_bibtex": dedent(r""" @article{Radosavovic2023, title={Robot Learning with Sensorimotor Pre-training}, @@ -388,7 +388,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://human-world-model.github.io/", - "paper": "https://arxiv.org/abs/2308.10901", + "paper": "https://huggingface.co/papers/2308.10901", "citation_bibtex": dedent(r""" @inproceedings{mendonca2023structured, title={Structured World Models from Human Videos}, @@ -401,7 +401,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://play-fusion.github.io/", - "paper": "https://arxiv.org/abs/2312.04549", + "paper": "https://huggingface.co/papers/2312.04549", "citation_bibtex": dedent(r""" @inproceedings{chen2023playfusion, title={PlayFusion: Skill Acquisition via Diffusion from Language-Annotated Play}, @@ -414,7 +414,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://robo-affordances.github.io/", - "paper": "https://arxiv.org/abs/2304.08488", + "paper": "https://huggingface.co/papers/2304.08488", "citation_bibtex": dedent(r""" @inproceedings{bahl2023affordances, title={Affordances from Human Videos as a Versatile Representation for Robotics}, @@ -433,7 +433,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://diffusion-policy.cs.columbia.edu/", - "paper": "https://arxiv.org/abs/2303.04137v5", + "paper": "https://huggingface.co/papers/2303.04137", "citation_bibtex": dedent(r""" @inproceedings{chi2023diffusionpolicy, title={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, @@ -505,7 +505,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://droid-dataset.github.io/", - "paper": "https://arxiv.org/abs/2403.12945", + "paper": "https://huggingface.co/papers/2403.12945", "citation_bibtex": dedent(r""" @article{khazatsky2024droid, title = {DROID: A Large-Scale In-The-Wild Robot Manipulation Dataset}, @@ -517,7 +517,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "cc-by-4.0", "url": "https://functional-manipulation-benchmark.github.io/", - "paper": "https://arxiv.org/abs/2401.08553", + "paper": "https://huggingface.co/papers/2401.08553", "citation_bibtex": dedent(r""" @article{luo2024fmb, title={FMB: a Functional Manipulation Benchmark for Generalizable Robotic Learning}, @@ -530,7 +530,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://openreview.net/forum?id=WuBv9-IGDUA", - "paper": "https://arxiv.org/abs/2401.14502", + "paper": "https://huggingface.co/papers/2401.14502", "citation_bibtex": dedent(r""" @inproceedings{saxena2023multiresolution, title={Multi-Resolution Sensing for Real-Time Control with Vision-Language Models}, @@ -575,7 +575,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://jyopari.github.io/VINN/", - "paper": "https://arxiv.org/abs/2112.01511", + "paper": "https://huggingface.co/papers/2112.01511", "citation_bibtex": dedent(r""" @misc{pari2021surprising, title={The Surprising Effectiveness of Representation Learning for Visual Imitation}, @@ -590,7 +590,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://play-to-policy.github.io/", - "paper": "https://arxiv.org/abs/2210.10047", + "paper": "https://huggingface.co/papers/2210.10047", "citation_bibtex": dedent(r""" @article{cui2022play, title = {From Play to Policy: Conditional Behavior Generation from Uncurated Robot Data}, @@ -603,7 +603,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://rot-robot.github.io/", - "paper": "https://arxiv.org/abs/2206.15469", + "paper": "https://huggingface.co/papers/2206.15469", "citation_bibtex": dedent(r""" @inproceedings{haldar2023watch, title={Watch and match: Supercharging imitation with regularized optimal transport}, @@ -633,7 +633,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/hydra-il-2023", - "paper": "https://arxiv.org/abs/2306.17237", + "paper": "https://huggingface.co/papers/2306.17237", "citation_bibtex": dedent(r""" @article{belkhale2023hydra, title={HYDRA: Hybrid Robot Actions for Imitation Learning}, @@ -646,21 +646,21 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://sites.google.com/view/visionandtouch", - "paper": "https://arxiv.org/abs/1810.10191", + "paper": "https://huggingface.co/papers/1810.10191", "citation_bibtex": dedent(r""" @inproceedings{lee2019icra, title={Making sense of vision and touch: Self-supervised learning of multimodal representations for contact-rich tasks}, author={Lee, Michelle A and Zhu, Yuke and Srinivasan, Krishnan and Shah, Parth and Savarese, Silvio and Fei-Fei, Li and Garg, Animesh and Bohg, Jeannette}, booktitle={2019 IEEE International Conference on Robotics and Automation (ICRA)}, year={2019}, - url={https://arxiv.org/abs/1810.10191} + url={https://huggingface.co/papers/1810.10191} }""").lstrip(), }, "stanford_robocook": { "tasks_col": "language_instruction", "license": "mit", "url": "https://hshi74.github.io/robocook/", - "paper": "https://arxiv.org/abs/2306.14447", + "paper": "https://huggingface.co/papers/2306.14447", "citation_bibtex": dedent(r""" @article{shi2023robocook, title={RoboCook: Long-Horizon Elasto-Plastic Object Manipulation with Diverse Tools}, @@ -673,7 +673,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "cc-by-4.0", "url": "https://www.kaggle.com/datasets/oiermees/taco-robot", - "paper": "https://arxiv.org/abs/2209.08959, https://arxiv.org/abs/2210.01911", + "paper": "https://huggingface.co/papers/2209.08959, https://huggingface.co/papers/2210.01911", "citation_bibtex": dedent(r""" @inproceedings{rosete2022tacorl, author = {Erick Rosete-Beas and Oier Mees and Gabriel Kalweit and Joschka Boedecker and Wolfram Burgard}, @@ -693,7 +693,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "URL", - "paper": "https://arxiv.org/abs/2107.05842", + "paper": "https://huggingface.co/papers/2107.05842", "citation_bibtex": dedent(r""" @Article{Osa22, author = {Takayuki Osa}, @@ -709,7 +709,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://toto-benchmark.org/", - "paper": "https://arxiv.org/abs/2306.00942", + "paper": "https://huggingface.co/papers/2306.00942", "citation_bibtex": dedent(r""" @inproceedings{zhou2023train, author={Zhou, Gaoyue and Dean, Victoria and Srirama, Mohan Kumar and Rajeswaran, Aravind and Pari, Jyothish and Hatch, Kyle and Jain, Aryan and Yu, Tianhe and Abbeel, Pieter and Pinto, Lerrel and Finn, Chelsea and Gupta, Abhinav}, @@ -733,7 +733,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://owmcorl.github.io/#", - "paper": "https://arxiv.org/abs/2310.16029", + "paper": "https://huggingface.co/papers/2310.16029", "citation_bibtex": dedent(r""" @preprint{Feng2023Finetuning, title={Finetuning Offline World Models in the Real World}, @@ -745,7 +745,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://robopil.github.io/d3fields/", - "paper": "https://arxiv.org/abs/2309.16118", + "paper": "https://huggingface.co/papers/2309.16118", "citation_bibtex": dedent(r""" @article{wang2023d3field, title={D^3Field: Dynamic 3D Descriptor Fields for Generalizable Robotic Manipulation}, @@ -758,7 +758,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://uscresl.github.io/dmfd/", - "paper": "https://arxiv.org/abs/2207.10148", + "paper": "https://huggingface.co/papers/2207.10148", "citation_bibtex": dedent(r""" @article{salhotra2022dmfd, author={Salhotra, Gautam and Liu, I-Chun Arthur and Dominguez-Kuhne, Marcus and Sukhatme, Gaurav S.}, @@ -775,7 +775,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/MUTEX/", - "paper": "https://arxiv.org/abs/2309.14320", + "paper": "https://huggingface.co/papers/2309.14320", "citation_bibtex": dedent(r""" @inproceedings{shah2023mutex, title={{MUTEX}: Learning Unified Policies from Multimodal Task Specifications}, @@ -811,7 +811,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://saytap.github.io/", - "paper": "https://arxiv.org/abs/2306.07580", + "paper": "https://huggingface.co/papers/2306.07580", "citation_bibtex": dedent(r""" @article{saytap2023, author = {Yujin Tang and Wenhao Yu and Jie Tan and Heiga Zen and Aleksandra Faust and @@ -847,7 +847,7 @@ DATASETS = { "tasks_col": "language_instruction", "license": "mit", "url": "https://ut-austin-rpl.github.io/VIOLA/", - "paper": "https://arxiv.org/abs/2210.11339", + "paper": "https://huggingface.co/papers/2210.11339", "citation_bibtex": dedent(r""" @article{zhu2022viola, title={VIOLA: Imitation Learning for Vision-Based Manipulation with Object Proposal Priors}, diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index c761f38b4..c6ef6112a 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -141,8 +141,7 @@ from lerobot.common.datasets.video_utils import ( get_image_pixel_channels, get_video_info, ) -from lerobot.common.robot_devices.robots.configs import RobotConfig -from lerobot.common.robot_devices.robots.utils import make_robot_config +from lerobot.common.robots import RobotConfig V16 = "v1.6" V20 = "v2.0" @@ -596,6 +595,30 @@ def convert_dataset( create_branch(repo_id=repo_id, branch=V20, repo_type="dataset") +def make_robot_config(robot_type: str, **kwargs) -> RobotConfig: + if robot_type == "aloha": + raise NotImplementedError # TODO + + elif robot_type == "koch_follower": + from lerobot.common.robots.koch_follower import KochFollowerConfig + + return KochFollowerConfig(**kwargs) + elif robot_type == "so100_follower": + from lerobot.common.robots.so100_follower import SO100FollowerConfig + + return SO100FollowerConfig(**kwargs) + elif robot_type == "stretch": + from lerobot.common.robots.stretch3 import Stretch3RobotConfig + + return Stretch3RobotConfig(**kwargs) + elif robot_type == "lekiwi": + from lerobot.common.robots.lekiwi import LeKiwiConfig + + return LeKiwiConfig(**kwargs) + else: + raise ValueError(f"Robot type '{robot_type}' is not available.") + + def main(): parser = argparse.ArgumentParser() task_args = parser.add_mutually_exclusive_group(required=True) diff --git a/lerobot/common/envs/configs.py b/lerobot/common/envs/configs.py index cf90048a3..ea081e9fb 100644 --- a/lerobot/common/envs/configs.py +++ b/lerobot/common/envs/configs.py @@ -14,10 +14,13 @@ import abc from dataclasses import dataclass, field +from typing import Any, Optional import draccus -from lerobot.common.constants import ACTION, OBS_ENV, OBS_IMAGE, OBS_IMAGES, OBS_ROBOT +from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.common.robots import RobotConfig +from lerobot.common.teleoperators.config import TeleoperatorConfig from lerobot.configs.types import FeatureType, PolicyFeature @@ -32,7 +35,8 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): def type(self) -> str: return self.get_choice_name(self.__class__) - @abc.abstractproperty + @property + @abc.abstractmethod def gym_kwargs(self) -> dict: raise NotImplementedError() @@ -53,7 +57,7 @@ class AlohaEnv(EnvConfig): features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, - "agent_pos": OBS_ROBOT, + "agent_pos": OBS_STATE, "top": f"{OBS_IMAGE}.top", "pixels/top": f"{OBS_IMAGES}.top", } @@ -94,8 +98,8 @@ class PushtEnv(EnvConfig): features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, - "agent_pos": OBS_ROBOT, - "environment_state": OBS_ENV, + "agent_pos": OBS_STATE, + "environment_state": OBS_ENV_STATE, "pixels": OBS_IMAGE, } ) @@ -136,7 +140,7 @@ class XarmEnv(EnvConfig): features_map: dict[str, str] = field( default_factory=lambda: { "action": ACTION, - "agent_pos": OBS_ROBOT, + "agent_pos": OBS_STATE, "pixels": OBS_IMAGE, } ) @@ -154,3 +158,116 @@ class XarmEnv(EnvConfig): "visualization_height": self.visualization_height, "max_episode_steps": self.episode_length, } + + +@dataclass +class VideoRecordConfig: + """Configuration for video recording in ManiSkill environments.""" + + enabled: bool = False + record_dir: str = "videos" + trajectory_name: str = "trajectory" + + +@dataclass +class EnvTransformConfig: + """Configuration for environment wrappers.""" + + # ee_action_space_params: EEActionSpaceConfig = field(default_factory=EEActionSpaceConfig) + control_mode: str = "gamepad" + display_cameras: bool = False + add_joint_velocity_to_observation: bool = False + add_current_to_observation: bool = False + add_ee_pose_to_observation: bool = False + crop_params_dict: Optional[dict[str, tuple[int, int, int, int]]] = None + resize_size: Optional[tuple[int, int]] = None + control_time_s: float = 20.0 + fixed_reset_joint_positions: Optional[Any] = None + reset_time_s: float = 5.0 + use_gripper: bool = True + gripper_quantization_threshold: float | None = 0.8 + gripper_penalty: float = 0.0 + gripper_penalty_in_reward: bool = False + + +@EnvConfig.register_subclass(name="gym_manipulator") +@dataclass +class HILSerlRobotEnvConfig(EnvConfig): + """Configuration for the HILSerlRobotEnv environment.""" + + robot: Optional[RobotConfig] = None + teleop: Optional[TeleoperatorConfig] = None + wrapper: Optional[EnvTransformConfig] = None + fps: int = 10 + name: str = "real_robot" + mode: str = None # Either "record", "replay", None + repo_id: Optional[str] = None + dataset_root: Optional[str] = None + task: str = "" + num_episodes: int = 10 # only for record mode + episode: int = 0 + device: str = "cuda" + push_to_hub: bool = True + pretrained_policy_name_or_path: Optional[str] = None + reward_classifier_pretrained_path: Optional[str] = None + # For the reward classifier, to record more positive examples after a success + number_of_steps_after_success: int = 0 + + def gym_kwargs(self) -> dict: + return {} + + +@EnvConfig.register_subclass("hil") +@dataclass +class HILEnvConfig(EnvConfig): + """Configuration for the HIL environment.""" + + type: str = "hil" + name: str = "PandaPickCube" + task: str = "PandaPickCubeKeyboard-v0" + use_viewer: bool = True + gripper_penalty: float = 0.0 + use_gamepad: bool = True + state_dim: int = 18 + action_dim: int = 4 + fps: int = 100 + episode_length: int = 100 + video_record: VideoRecordConfig = field(default_factory=VideoRecordConfig) + features: dict[str, PolicyFeature] = field( + default_factory=lambda: { + "action": PolicyFeature(type=FeatureType.ACTION, shape=(4,)), + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 128, 128)), + "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(18,)), + } + ) + features_map: dict[str, str] = field( + default_factory=lambda: { + "action": ACTION, + "observation.image": OBS_IMAGE, + "observation.state": OBS_STATE, + } + ) + ################# args from hilserlrobotenv + reward_classifier_pretrained_path: Optional[str] = None + robot_config: Optional[RobotConfig] = None + teleop_config: Optional[TeleoperatorConfig] = None + wrapper: Optional[EnvTransformConfig] = None + mode: str = None # Either "record", "replay", None + repo_id: Optional[str] = None + dataset_root: Optional[str] = None + num_episodes: int = 10 # only for record mode + episode: int = 0 + device: str = "cuda" + push_to_hub: bool = True + pretrained_policy_name_or_path: Optional[str] = None + # For the reward classifier, to record more positive examples after a success + number_of_steps_after_success: int = 0 + ############################ + + @property + def gym_kwargs(self) -> dict: + return { + "use_viewer": self.use_viewer, + "use_gamepad": self.use_gamepad, + "gripper_penalty": self.gripper_penalty, + } diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 8450f84b9..4f5d59c69 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -17,7 +17,7 @@ import importlib import gymnasium as gym -from lerobot.common.envs.configs import AlohaEnv, EnvConfig, PushtEnv, XarmEnv +from lerobot.common.envs.configs import AlohaEnv, EnvConfig, HILEnvConfig, PushtEnv, XarmEnv def make_env_config(env_type: str, **kwargs) -> EnvConfig: @@ -27,6 +27,8 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig: return PushtEnv(**kwargs) elif env_type == "xarm": return XarmEnv(**kwargs) + elif env_type == "hil": + return HILEnvConfig(**kwargs) else: raise ValueError(f"Policy type '{env_type}' is not available.") diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py index 83334f876..66d6e5f93 100644 --- a/lerobot/common/envs/utils.py +++ b/lerobot/common/envs/utils.py @@ -47,6 +47,10 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten # TODO(aliberts, rcadene): use transforms.ToTensor()? img = torch.from_numpy(img) + # When preprocessing observations in a non-vectorized environment, we need to add a batch dimension. + # This is the case for human-in-the-loop RL where there is only one environment. + if img.ndim == 3: + img = img.unsqueeze(0) # sanity check that images are channel last _, h, w, c = img.shape assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}" @@ -62,13 +66,18 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten return_observations[imgkey] = img if "environment_state" in observations: - return_observations["observation.environment_state"] = torch.from_numpy( - observations["environment_state"] - ).float() + env_state = torch.from_numpy(observations["environment_state"]).float() + if env_state.dim() == 1: + env_state = env_state.unsqueeze(0) + + return_observations["observation.environment_state"] = env_state # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing - # requirement for "agent_pos" - return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float() + agent_pos = torch.from_numpy(observations["agent_pos"]).float() + if agent_pos.dim() == 1: + agent_pos = agent_pos.unsqueeze(0) + return_observations["observation.state"] = agent_pos + return return_observations diff --git a/lerobot/common/errors.py b/lerobot/common/errors.py new file mode 100644 index 000000000..c02d568d4 --- /dev/null +++ b/lerobot/common/errors.py @@ -0,0 +1,43 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class DeviceNotConnectedError(ConnectionError): + """Exception raised when the device is not connected.""" + + def __init__(self, message="This device is not connected. Try calling `connect()` first."): + self.message = message + super().__init__(self.message) + + +class DeviceAlreadyConnectedError(ConnectionError): + """Exception raised when the device is already connected.""" + + def __init__( + self, + message="This device is already connected. Try not calling `connect()` twice.", + ): + self.message = message + super().__init__(self.message) + + +class InvalidActionError(ValueError): + """Exception raised when an action is already invalid.""" + + def __init__( + self, + message="The action is invalid. Check the value follows what it is expected from the action space.", + ): + self.message = message + super().__init__(self.message) diff --git a/lerobot/common/model/kinematics.py b/lerobot/common/model/kinematics.py new file mode 100644 index 000000000..367b609e1 --- /dev/null +++ b/lerobot/common/model/kinematics.py @@ -0,0 +1,483 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import numpy as np +from numpy.typing import NDArray +from scipy.spatial.transform import Rotation + + +def skew_symmetric(w: NDArray[np.float32]) -> NDArray[np.float32]: + """Creates the skew-symmetric matrix from a 3D vector.""" + return np.array([[0, -w[2], w[1]], [w[2], 0, -w[0]], [-w[1], w[0], 0]]) + + +def rodrigues_rotation(w: NDArray[np.float32], theta: float) -> NDArray[np.float32]: + """Computes the rotation matrix using Rodrigues' formula.""" + w_hat = skew_symmetric(w) + return np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + + +def screw_axis_to_transform(s: NDArray[np.float32], theta: float) -> NDArray[np.float32]: + """Converts a screw axis to a 4x4 transformation matrix.""" + screw_axis_rot = s[:3] + screw_axis_trans = s[3:] + + # Pure translation + if np.allclose(screw_axis_rot, 0) and np.linalg.norm(screw_axis_trans) == 1: + transform = np.eye(4) + transform[:3, 3] = screw_axis_trans * theta + + # Rotation (and potentially translation) + elif np.linalg.norm(screw_axis_rot) == 1: + w_hat = skew_symmetric(screw_axis_rot) + rot_mat = np.eye(3) + np.sin(theta) * w_hat + (1 - np.cos(theta)) * w_hat @ w_hat + t = ( + np.eye(3) * theta + (1 - np.cos(theta)) * w_hat + (theta - np.sin(theta)) * w_hat @ w_hat + ) @ screw_axis_trans + transform = np.eye(4) + transform[:3, :3] = rot_mat + transform[:3, 3] = t + else: + raise ValueError("Invalid screw axis parameters") + return transform + + +def pose_difference_se3(pose1: NDArray[np.float32], pose2: NDArray[np.float32]) -> NDArray[np.float32]: + """ + Calculates the SE(3) difference between two 4x4 homogeneous transformation matrices. + SE(3) (Special Euclidean Group) represents rigid body transformations in 3D space, + combining rotation (SO(3)) and translation. + + Each 4x4 matrix has the following structure: + [R11 R12 R13 tx] + [R21 R22 R23 ty] + [R31 R32 R33 tz] + [ 0 0 0 1] + + where R is the 3x3 rotation matrix and [tx,ty,tz] is the translation vector. + + Args: + pose1: A 4x4 numpy array representing the first pose. + pose2: A 4x4 numpy array representing the second pose. + + Returns: + A 6D numpy array concatenating translation and rotation differences. + First 3 elements are the translational difference (position). + Last 3 elements are the rotational difference in axis-angle representation. + """ + rot1 = pose1[:3, :3] + rot2 = pose2[:3, :3] + + translation_diff = pose1[:3, 3] - pose2[:3, 3] + + # Calculate rotational difference using scipy's Rotation library + rot_diff = Rotation.from_matrix(rot1 @ rot2.T) + rotation_diff = rot_diff.as_rotvec() # Axis-angle representation + + return np.concatenate([translation_diff, rotation_diff]) + + +def se3_error(target_pose: NDArray[np.float32], current_pose: NDArray[np.float32]) -> NDArray[np.float32]: + pos_error = target_pose[:3, 3] - current_pose[:3, 3] + + rot_target = target_pose[:3, :3] + rot_current = current_pose[:3, :3] + rot_error_mat = rot_target @ rot_current.T + rot_error = Rotation.from_matrix(rot_error_mat).as_rotvec() + + return np.concatenate([pos_error, rot_error]) + + +class RobotKinematics: + """Robot kinematics class supporting multiple robot models.""" + + # Robot measurements dictionary + ROBOT_MEASUREMENTS = { + "koch": { + "gripper": [0.239, -0.001, 0.024], + "wrist": [0.209, 0, 0.024], + "forearm": [0.108, 0, 0.02], + "humerus": [0, 0, 0.036], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "moss": { + "gripper": [0.246, 0.013, 0.111], + "wrist": [0.245, 0.002, 0.064], + "forearm": [0.122, 0, 0.064], + "humerus": [0.001, 0.001, 0.063], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "so_old_calibration": { + "gripper": [0.320, 0, 0.050], + "wrist": [0.278, 0, 0.050], + "forearm": [0.143, 0, 0.044], + "humerus": [0.031, 0, 0.072], + "shoulder": [0, 0, 0], + "base": [0, 0, 0.02], + }, + "so_new_calibration": { + "gripper": [0.33, 0.0, 0.285], + "wrist": [0.30, 0.0, 0.267], + "forearm": [0.25, 0.0, 0.266], + "humerus": [0.06, 0.0, 0.264], + "shoulder": [0.0, 0.0, 0.238], + "base": [0.0, 0.0, 0.12], + }, + } + + def __init__(self, robot_type: str = "so100"): + """Initialize kinematics for the specified robot type. + + Args: + robot_type: String specifying the robot model ("koch", "so100", or "moss") + """ + if robot_type not in self.ROBOT_MEASUREMENTS: + raise ValueError( + f"Unknown robot type: {robot_type}. Available types: {list(self.ROBOT_MEASUREMENTS.keys())}" + ) + + self.robot_type = robot_type + self.measurements = self.ROBOT_MEASUREMENTS[robot_type] + + # Initialize all transformation matrices and screw axes + self._setup_transforms() + + def _create_translation_matrix( + self, x: float = 0.0, y: float = 0.0, z: float = 0.0 + ) -> NDArray[np.float32]: + """Create a 4x4 translation matrix.""" + return np.array([[1, 0, 0, x], [0, 1, 0, y], [0, 0, 1, z], [0, 0, 0, 1]]) + + def _setup_transforms(self): + """Setup all transformation matrices and screw axes for the robot.""" + # Set up rotation matrices (constant across robot types) + + # Gripper orientation + self.gripper_X0 = np.array( + [ + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, -1, 0, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Wrist orientation + self.wrist_X0 = np.array( + [ + [0, -1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Base orientation + self.base_X0 = np.array( + [ + [0, 0, 1, 0], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 1], + ], + dtype=np.float32, + ) + + # Gripper + # Screw axis of gripper frame wrt base frame + self.S_BG = np.array( + [ + 1, + 0, + 0, + 0, + self.measurements["gripper"][2], + -self.measurements["gripper"][1], + ], + dtype=np.float32, + ) + + # Gripper origin to centroid transform + self.X_GoGc = self._create_translation_matrix(x=0.07) + + # Gripper origin to tip transform + self.X_GoGt = self._create_translation_matrix(x=0.12) + + # 0-position gripper frame pose wrt base + self.X_BoGo = self._create_translation_matrix( + x=self.measurements["gripper"][0], + y=self.measurements["gripper"][1], + z=self.measurements["gripper"][2], + ) + + # Wrist + # Screw axis of wrist frame wrt base frame + self.S_BR = np.array( + [0, 1, 0, -self.measurements["wrist"][2], 0, self.measurements["wrist"][0]], dtype=np.float32 + ) + + # 0-position origin to centroid transform + self.X_RoRc = self._create_translation_matrix(x=0.0035, y=-0.002) + + # 0-position wrist frame pose wrt base + self.X_BR = self._create_translation_matrix( + x=self.measurements["wrist"][0], + y=self.measurements["wrist"][1], + z=self.measurements["wrist"][2], + ) + + # Forearm + # Screw axis of forearm frame wrt base frame + self.S_BF = np.array( + [ + 0, + 1, + 0, + -self.measurements["forearm"][2], + 0, + self.measurements["forearm"][0], + ], + dtype=np.float32, + ) + + # Forearm origin + centroid transform + self.X_ForearmFc = self._create_translation_matrix(x=0.036) + + # 0-position forearm frame pose wrt base + self.X_BF = self._create_translation_matrix( + x=self.measurements["forearm"][0], + y=self.measurements["forearm"][1], + z=self.measurements["forearm"][2], + ) + + # Humerus + # Screw axis of humerus frame wrt base frame + self.S_BH = np.array( + [ + 0, + -1, + 0, + self.measurements["humerus"][2], + 0, + -self.measurements["humerus"][0], + ], + dtype=np.float32, + ) + + # Humerus origin to centroid transform + self.X_HoHc = self._create_translation_matrix(x=0.0475) + + # 0-position humerus frame pose wrt base + self.X_BH = self._create_translation_matrix( + x=self.measurements["humerus"][0], + y=self.measurements["humerus"][1], + z=self.measurements["humerus"][2], + ) + + # Shoulder + # Screw axis of shoulder frame wrt Base frame + self.S_BS = np.array([0, 0, -1, 0, 0, 0], dtype=np.float32) + + # Shoulder origin to centroid transform + self.X_SoSc = self._create_translation_matrix(x=-0.017, z=0.0235) + + # 0-position shoulder frame pose wrt base + self.X_BS = self._create_translation_matrix( + x=self.measurements["shoulder"][0], + y=self.measurements["shoulder"][1], + z=self.measurements["shoulder"][2], + ) + + # Base + # Base origin to centroid transform + self.X_BoBc = self._create_translation_matrix(y=0.015) + + # World to base transform + self.X_WoBo = self._create_translation_matrix( + x=self.measurements["base"][0], + y=self.measurements["base"][1], + z=self.measurements["base"][2], + ) + + # Pre-compute gripper post-multiplication matrix + self._fk_gripper_post = self.X_GoGc @ self.X_BoGo @ self.gripper_X0 + + def forward_kinematics( + self, + robot_pos_deg: NDArray[np.float32], + frame: str = "gripper_tip", + ) -> NDArray[np.float32]: + """Generic forward kinematics. + + Args: + robot_pos_deg: Joint positions in degrees. Can be ``None`` when + computing the *base* frame as it does not depend on joint + angles. + frame: Target frame. One of + ``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``. + + Returns + ------- + NDArray[np.float32] + 4×4 homogeneous transformation matrix of the requested frame + expressed in the world coordinate system. + """ + frame = frame.lower() + if frame not in { + "base", + "shoulder", + "humerus", + "forearm", + "wrist", + "gripper", + "gripper_tip", + }: + raise ValueError( + f"Unknown frame '{frame}'. Valid options are base, shoulder, humerus, forearm, wrist, gripper, gripper_tip." + ) + + # Base frame does not rely on joint angles. + if frame == "base": + return self.X_WoBo @ self.X_BoBc @ self.base_X0 + + robot_pos_rad = robot_pos_deg / 180 * np.pi + + # Extract joint angles (note the sign convention for shoulder lift). + theta_shoulder_pan = robot_pos_rad[0] + theta_shoulder_lift = -robot_pos_rad[1] + theta_elbow_flex = robot_pos_rad[2] + theta_wrist_flex = robot_pos_rad[3] + theta_wrist_roll = robot_pos_rad[4] + + # Start with the world-to-base transform; incrementally add successive links. + transformation_matrix = self.X_WoBo @ screw_axis_to_transform(self.S_BS, theta_shoulder_pan) + if frame == "shoulder": + return transformation_matrix @ self.X_SoSc @ self.X_BS + + transformation_matrix = transformation_matrix @ screw_axis_to_transform( + self.S_BH, theta_shoulder_lift + ) + if frame == "humerus": + return transformation_matrix @ self.X_HoHc @ self.X_BH + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BF, theta_elbow_flex) + if frame == "forearm": + return transformation_matrix @ self.X_ForearmFc @ self.X_BF + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BR, theta_wrist_flex) + if frame == "wrist": + return transformation_matrix @ self.X_RoRc @ self.X_BR @ self.wrist_X0 + + transformation_matrix = transformation_matrix @ screw_axis_to_transform(self.S_BG, theta_wrist_roll) + if frame == "gripper": + return transformation_matrix @ self._fk_gripper_post + else: # frame == "gripper_tip" + return transformation_matrix @ self.X_GoGt @ self.X_BoGo @ self.gripper_X0 + + def compute_jacobian( + self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip" + ) -> NDArray[np.float32]: + """Finite differences to compute the Jacobian. + J(i, j) represents how the ith component of the end-effector's velocity changes wrt a small change + in the jth joint's velocity. + + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + + eps = 1e-8 + jac = np.zeros(shape=(6, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + sdot = ( + pose_difference_se3( + self.forward_kinematics(robot_pos_deg[:-1] + delta, frame), + self.forward_kinematics(robot_pos_deg[:-1] - delta, frame), + ) + / eps + ) + jac[:, el_ix] = sdot + return jac + + def compute_positional_jacobian( + self, robot_pos_deg: NDArray[np.float32], frame: str = "gripper_tip" + ) -> NDArray[np.float32]: + """Finite differences to compute the positional Jacobian. + J(i, j) represents how the ith component of the end-effector's position changes wrt a small change + in the jth joint's velocity. + + Args: + robot_pos_deg: Current joint positions in degrees + fk_func: Forward kinematics function to use (defaults to fk_gripper) + """ + eps = 1e-8 + jac = np.zeros(shape=(3, 5)) + delta = np.zeros(len(robot_pos_deg[:-1]), dtype=np.float64) + for el_ix in range(len(robot_pos_deg[:-1])): + delta *= 0 + delta[el_ix] = eps / 2 + sdot = ( + self.forward_kinematics(robot_pos_deg[:-1] + delta, frame)[:3, 3] + - self.forward_kinematics(robot_pos_deg[:-1] - delta, frame)[:3, 3] + ) / eps + jac[:, el_ix] = sdot + return jac + + def ik( + self, + current_joint_pos: NDArray[np.float32], + desired_ee_pose: NDArray[np.float32], + position_only: bool = True, + frame: str = "gripper_tip", + max_iterations: int = 5, + learning_rate: float = 1, + ) -> NDArray[np.float32]: + """Inverse kinematics using gradient descent. + + Args: + current_joint_state: Initial joint positions in degrees + desired_ee_pose: Target end-effector pose as a 4x4 transformation matrix + position_only: If True, only match end-effector position, not orientation + frame: Target frame. One of + ``{"base", "shoulder", "humerus", "forearm", "wrist", "gripper", "gripper_tip"}``. + max_iterations: Maximum number of iterations to run + learning_rate: Learning rate for gradient descent + + Returns: + Joint positions in degrees that achieve the desired end-effector pose + """ + # Do gradient descent. + current_joint_state = current_joint_pos.copy() + for _ in range(max_iterations): + current_ee_pose = self.forward_kinematics(current_joint_state, frame) + if not position_only: + error = se3_error(desired_ee_pose, current_ee_pose) + jac = self.compute_jacobian(current_joint_state, frame) + else: + error = desired_ee_pose[:3, 3] - current_ee_pose[:3, 3] + jac = self.compute_positional_jacobian(current_joint_state, frame) + delta_angles = np.linalg.pinv(jac) @ error + current_joint_state[:-1] += learning_rate * delta_angles + + if np.linalg.norm(error) < 5e-3: + return current_joint_state + return current_joint_state diff --git a/lerobot/common/motors/__init__.py b/lerobot/common/motors/__init__.py new file mode 100644 index 000000000..dfbfbaee8 --- /dev/null +++ b/lerobot/common/motors/__init__.py @@ -0,0 +1 @@ +from .motors_bus import Motor, MotorCalibration, MotorNormMode, MotorsBus diff --git a/lerobot/common/motors/dynamixel/__init__.py b/lerobot/common/motors/dynamixel/__init__.py new file mode 100644 index 000000000..3e414557e --- /dev/null +++ b/lerobot/common/motors/dynamixel/__init__.py @@ -0,0 +1,2 @@ +from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode +from .tables import * diff --git a/lerobot/common/motors/dynamixel/dynamixel.py b/lerobot/common/motors/dynamixel/dynamixel.py new file mode 100644 index 000000000..9f0db901d --- /dev/null +++ b/lerobot/common/motors/dynamixel/dynamixel.py @@ -0,0 +1,263 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(aliberts): Should we implement FastSyncRead/Write? +# https://github.com/ROBOTIS-GIT/DynamixelSDK/pull/643 +# https://github.com/ROBOTIS-GIT/DynamixelSDK/releases/tag/3.8.2 +# https://emanual.robotis.com/docs/en/dxl/protocol2/#fast-sync-read-0x8a +# -> Need to check compatibility across models + +import logging +from copy import deepcopy +from enum import Enum + +from lerobot.common.utils.encoding_utils import decode_twos_complement, encode_twos_complement + +from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from .tables import ( + AVAILABLE_BAUDRATES, + MODEL_BAUDRATE_TABLE, + MODEL_CONTROL_TABLE, + MODEL_ENCODING_TABLE, + MODEL_NUMBER_TABLE, + MODEL_RESOLUTION, +) + +PROTOCOL_VERSION = 2.0 +DEFAULT_BAUDRATE = 1_000_000 +DEFAULT_TIMEOUT_MS = 1000 + +NORMALIZED_DATA = ["Goal_Position", "Present_Position"] + +logger = logging.getLogger(__name__) + + +class OperatingMode(Enum): + # DYNAMIXEL only controls current(torque) regardless of speed and position. This mode is ideal for a + # gripper or a system that only uses current(torque) control or a system that has additional + # velocity/position controllers. + CURRENT = 0 + + # This mode controls velocity. This mode is identical to the Wheel Mode(endless) from existing DYNAMIXEL. + # This mode is ideal for wheel-type robots. + VELOCITY = 1 + + # This mode controls position. This mode is identical to the Joint Mode from existing DYNAMIXEL. Operating + # position range is limited by the Max Position Limit(48) and the Min Position Limit(52). This mode is + # ideal for articulated robots that each joint rotates less than 360 degrees. + POSITION = 3 + + # This mode controls position. This mode is identical to the Multi-turn Position Control from existing + # DYNAMIXEL. 512 turns are supported(-256[rev] ~ 256[rev]). This mode is ideal for multi-turn wrists or + # conveyer systems or a system that requires an additional reduction gear. Note that Max Position + # Limit(48), Min Position Limit(52) are not used on Extended Position Control Mode. + EXTENDED_POSITION = 4 + + # This mode controls both position and current(torque). Up to 512 turns are supported (-256[rev] ~ + # 256[rev]). This mode is ideal for a system that requires both position and current control such as + # articulated robots or grippers. + CURRENT_POSITION = 5 + + # This mode directly controls PWM output. (Voltage Control Mode) + PWM = 16 + + +class DriveMode(Enum): + NON_INVERTED = 0 + INVERTED = 1 + + +class TorqueMode(Enum): + ENABLED = 1 + DISABLED = 0 + + +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + import dynamixel_sdk as dxl + + if length == 1: + data = [value] + elif length == 2: + data = [dxl.DXL_LOBYTE(value), dxl.DXL_HIBYTE(value)] + elif length == 4: + data = [ + dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), + dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), + dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), + ] + return data + + +class DynamixelMotorsBus(MotorsBus): + """ + The Dynamixel implementation for a MotorsBus. It relies on the python dynamixel sdk to communicate with + the motors. For more info, see the Dynamixel SDK Documentation: + https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20 + """ + + apply_drive_mode = False + available_baudrates = deepcopy(AVAILABLE_BAUDRATES) + default_baudrate = DEFAULT_BAUDRATE + default_timeout = DEFAULT_TIMEOUT_MS + model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE) + model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) + model_encoding_table = deepcopy(MODEL_ENCODING_TABLE) + model_number_table = deepcopy(MODEL_NUMBER_TABLE) + model_resolution_table = deepcopy(MODEL_RESOLUTION) + normalized_data = deepcopy(NORMALIZED_DATA) + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + ): + super().__init__(port, motors, calibration) + import dynamixel_sdk as dxl + + self.port_handler = dxl.PortHandler(self.port) + self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) + self.sync_reader = dxl.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) + self.sync_writer = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) + self._comm_success = dxl.COMM_SUCCESS + self._no_error = 0x00 + + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: + pass + + def _handshake(self) -> None: + self._assert_motors_exist() + + def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: + model = self.motors[motor].model + search_baudrates = ( + [initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model] + ) + + for baudrate in search_baudrates: + self.set_baudrate(baudrate) + id_model = self.broadcast_ping() + if id_model: + found_id, found_model = next(iter(id_model.items())) + expected_model_nb = self.model_number_table[model] + if found_model != expected_model_nb: + raise RuntimeError( + f"Found one motor on {baudrate=} with id={found_id} but it has a " + f"model number '{found_model}' different than the one expected: '{expected_model_nb}'. " + f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')." + ) + return baudrate, found_id + + raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") + + def configure_motors(self) -> None: + # By default, Dynamixel motors have a 500µs delay response time (corresponding to a value of 250 on + # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). + for motor in self.motors: + self.write("Return_Delay_Time", motor, 0) + + @property + def is_calibrated(self) -> bool: + return self.calibration == self.read_calibration() + + def read_calibration(self) -> dict[str, MotorCalibration]: + offsets = self.sync_read("Homing_Offset", normalize=False) + mins = self.sync_read("Min_Position_Limit", normalize=False) + maxes = self.sync_read("Max_Position_Limit", normalize=False) + drive_modes = self.sync_read("Drive_Mode", normalize=False) + + calibration = {} + for motor, m in self.motors.items(): + calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=drive_modes[motor], + homing_offset=offsets[motor], + range_min=mins[motor], + range_max=maxes[motor], + ) + + return calibration + + def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None: + for motor, calibration in calibration_dict.items(): + self.write("Homing_Offset", motor, calibration.homing_offset) + self.write("Min_Position_Limit", motor, calibration.range_min) + self.write("Max_Position_Limit", motor, calibration.range_max) + + self.calibration = calibration_dict + + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + for motor in self._get_motors_list(motors): + self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) + + def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") + self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + for motor in self._get_motors_list(motors): + self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) + + def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + for id_ in ids_values: + model = self._id_to_model(id_) + encoding_table = self.model_encoding_table.get(model) + if encoding_table and data_name in encoding_table: + n_bytes = encoding_table[data_name] + ids_values[id_] = encode_twos_complement(ids_values[id_], n_bytes) + + return ids_values + + def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + for id_ in ids_values: + model = self._id_to_model(id_) + encoding_table = self.model_encoding_table.get(model) + if encoding_table and data_name in encoding_table: + n_bytes = encoding_table[data_name] + ids_values[id_] = decode_twos_complement(ids_values[id_], n_bytes) + + return ids_values + + def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: + """ + On Dynamixel Motors: + Present_Position = Actual_Position + Homing_Offset + """ + half_turn_homings = {} + for motor, pos in positions.items(): + model = self._get_motor_model(motor) + max_res = self.model_resolution_table[model] - 1 + half_turn_homings[motor] = int(max_res / 2) - pos + + return half_turn_homings + + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: + return _split_into_byte_chunks(value, length) + + def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: + for n_try in range(1 + num_retry): + data_list, comm = self.packet_handler.broadcastPing(self.port_handler) + if self._is_comm_success(comm): + break + logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") + logger.debug(self.packet_handler.getTxRxResult(comm)) + + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + + return + + return {id_: data[0] for id_, data in data_list.items()} diff --git a/lerobot/common/motors/dynamixel/tables.py b/lerobot/common/motors/dynamixel/tables.py new file mode 100644 index 000000000..8b67bbf38 --- /dev/null +++ b/lerobot/common/motors/dynamixel/tables.py @@ -0,0 +1,197 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(Steven): Consider doing the following: +# from enum import Enum +# class MyControlTableKey(Enum): +# ID = "ID" +# GOAL_SPEED = "Goal_Speed" +# ... +# +# MY_CONTROL_TABLE ={ +# MyControlTableKey.ID.value: (5,1) +# MyControlTableKey.GOAL_SPEED.value: (46, 2) +# ... +# } +# This allows me do to: +# bus.write(MyControlTableKey.GOAL_SPEED, ...) +# Instead of: +# bus.write("Goal_Speed", ...) +# This is important for two reasons: +# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError +# 2. We can change the value of the MyControlTableKey enums without impacting the client code + + +# {data_name: (address, size_byte)} +# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table +X_SERIES_CONTROL_TABLE = { + "Model_Number": (0, 2), + "Model_Information": (2, 4), + "Firmware_Version": (6, 1), + "ID": (7, 1), + "Baud_Rate": (8, 1), + "Return_Delay_Time": (9, 1), + "Drive_Mode": (10, 1), + "Operating_Mode": (11, 1), + "Secondary_ID": (12, 1), + "Protocol_Type": (13, 1), + "Homing_Offset": (20, 4), + "Moving_Threshold": (24, 4), + "Temperature_Limit": (31, 1), + "Max_Voltage_Limit": (32, 2), + "Min_Voltage_Limit": (34, 2), + "PWM_Limit": (36, 2), + "Current_Limit": (38, 2), + "Acceleration_Limit": (40, 4), + "Velocity_Limit": (44, 4), + "Max_Position_Limit": (48, 4), + "Min_Position_Limit": (52, 4), + "Shutdown": (63, 1), + "Torque_Enable": (64, 1), + "LED": (65, 1), + "Status_Return_Level": (68, 1), + "Registered_Instruction": (69, 1), + "Hardware_Error_Status": (70, 1), + "Velocity_I_Gain": (76, 2), + "Velocity_P_Gain": (78, 2), + "Position_D_Gain": (80, 2), + "Position_I_Gain": (82, 2), + "Position_P_Gain": (84, 2), + "Feedforward_2nd_Gain": (88, 2), + "Feedforward_1st_Gain": (90, 2), + "Bus_Watchdog": (98, 1), + "Goal_PWM": (100, 2), + "Goal_Current": (102, 2), + "Goal_Velocity": (104, 4), + "Profile_Acceleration": (108, 4), + "Profile_Velocity": (112, 4), + "Goal_Position": (116, 4), + "Realtime_Tick": (120, 2), + "Moving": (122, 1), + "Moving_Status": (123, 1), + "Present_PWM": (124, 2), + "Present_Current": (126, 2), + "Present_Velocity": (128, 4), + "Present_Position": (132, 4), + "Velocity_Trajectory": (136, 4), + "Position_Trajectory": (140, 4), + "Present_Input_Voltage": (144, 2), + "Present_Temperature": (146, 1), +} + +# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#baud-rate8 +X_SERIES_BAUDRATE_TABLE = { + 9_600: 0, + 57_600: 1, + 115_200: 2, + 1_000_000: 3, + 2_000_000: 4, + 3_000_000: 5, + 4_000_000: 6, +} + +# {data_name: size_byte} +X_SERIES_ENCODINGS_TABLE = { + "Homing_Offset": X_SERIES_CONTROL_TABLE["Homing_Offset"][1], + "Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1], + "Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1], + "Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1], + "Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1], + "Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1], + "Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1], +} + +MODEL_ENCODING_TABLE = { + "x_series": X_SERIES_ENCODINGS_TABLE, + "xl330-m077": X_SERIES_ENCODINGS_TABLE, + "xl330-m288": X_SERIES_ENCODINGS_TABLE, + "xl430-w250": X_SERIES_ENCODINGS_TABLE, + "xm430-w350": X_SERIES_ENCODINGS_TABLE, + "xm540-w270": X_SERIES_ENCODINGS_TABLE, + "xc430-w150": X_SERIES_ENCODINGS_TABLE, +} + +# {model: model_resolution} +# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#specifications +MODEL_RESOLUTION = { + "x_series": 4096, + "xl330-m077": 4096, + "xl330-m288": 4096, + "xl430-w250": 4096, + "xm430-w350": 4096, + "xm540-w270": 4096, + "xc430-w150": 4096, +} + +# {model: model_number} +# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#control-table-of-eeprom-area +MODEL_NUMBER_TABLE = { + "xl330-m077": 1190, + "xl330-m288": 1200, + "xl430-w250": 1060, + "xm430-w350": 1020, + "xm540-w270": 1120, + "xc430-w150": 1070, +} + +# {model: available_operating_modes} +# https://emanual.robotis.com/docs/en/dxl/x/{MODEL}/#operating-mode11 +MODEL_OPERATING_MODES = { + "xl330-m077": [0, 1, 3, 4, 5, 16], + "xl330-m288": [0, 1, 3, 4, 5, 16], + "xl430-w250": [1, 3, 4, 16], + "xm430-w350": [0, 1, 3, 4, 5, 16], + "xm540-w270": [0, 1, 3, 4, 5, 16], + "xc430-w150": [1, 3, 4, 16], +} + +MODEL_CONTROL_TABLE = { + "x_series": X_SERIES_CONTROL_TABLE, + "xl330-m077": X_SERIES_CONTROL_TABLE, + "xl330-m288": X_SERIES_CONTROL_TABLE, + "xl430-w250": X_SERIES_CONTROL_TABLE, + "xm430-w350": X_SERIES_CONTROL_TABLE, + "xm540-w270": X_SERIES_CONTROL_TABLE, + "xc430-w150": X_SERIES_CONTROL_TABLE, +} + +MODEL_BAUDRATE_TABLE = { + "x_series": X_SERIES_BAUDRATE_TABLE, + "xl330-m077": X_SERIES_BAUDRATE_TABLE, + "xl330-m288": X_SERIES_BAUDRATE_TABLE, + "xl430-w250": X_SERIES_BAUDRATE_TABLE, + "xm430-w350": X_SERIES_BAUDRATE_TABLE, + "xm540-w270": X_SERIES_BAUDRATE_TABLE, + "xc430-w150": X_SERIES_BAUDRATE_TABLE, +} + +AVAILABLE_BAUDRATES = [ + 9_600, + 19_200, + 38_400, + 57_600, + 115_200, + 230_400, + 460_800, + 500_000, + 576_000, + 921_600, + 1_000_000, + 1_152_000, + 2_000_000, + 2_500_000, + 3_000_000, + 3_500_000, + 4_000_000, +] diff --git a/lerobot/common/motors/feetech/__init__.py b/lerobot/common/motors/feetech/__init__.py new file mode 100644 index 000000000..911d1d19f --- /dev/null +++ b/lerobot/common/motors/feetech/__init__.py @@ -0,0 +1,2 @@ +from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode +from .tables import * diff --git a/lerobot/common/motors/feetech/feetech.py b/lerobot/common/motors/feetech/feetech.py new file mode 100644 index 000000000..4937fdea7 --- /dev/null +++ b/lerobot/common/motors/feetech/feetech.py @@ -0,0 +1,454 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from copy import deepcopy +from enum import Enum +from pprint import pformat + +from lerobot.common.utils.encoding_utils import decode_sign_magnitude, encode_sign_magnitude + +from ..motors_bus import Motor, MotorCalibration, MotorsBus, NameOrID, Value, get_address +from .tables import ( + FIRMWARE_MAJOR_VERSION, + FIRMWARE_MINOR_VERSION, + MODEL_BAUDRATE_TABLE, + MODEL_CONTROL_TABLE, + MODEL_ENCODING_TABLE, + MODEL_NUMBER, + MODEL_NUMBER_TABLE, + MODEL_PROTOCOL, + MODEL_RESOLUTION, + SCAN_BAUDRATES, +) + +DEFAULT_PROTOCOL_VERSION = 0 +DEFAULT_BAUDRATE = 1_000_000 +DEFAULT_TIMEOUT_MS = 1000 + +NORMALIZED_DATA = ["Goal_Position", "Present_Position"] + +logger = logging.getLogger(__name__) + + +class OperatingMode(Enum): + # position servo mode + POSITION = 0 + # The motor is in constant speed mode, which is controlled by parameter 0x2e, and the highest bit 15 is + # the direction bit + VELOCITY = 1 + # PWM open-loop speed regulation mode, with parameter 0x2c running time parameter control, bit11 as + # direction bit + PWM = 2 + # In step servo mode, the number of step progress is represented by parameter 0x2a, and the highest bit 15 + # is the direction bit + STEP = 3 + + +class DriveMode(Enum): + NON_INVERTED = 0 + INVERTED = 1 + + +class TorqueMode(Enum): + ENABLED = 1 + DISABLED = 0 + + +def _split_into_byte_chunks(value: int, length: int) -> list[int]: + import scservo_sdk as scs + + if length == 1: + data = [value] + elif length == 2: + data = [scs.SCS_LOBYTE(value), scs.SCS_HIBYTE(value)] + elif length == 4: + data = [ + scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), + scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), + scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), + scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), + ] + return data + + +def patch_setPacketTimeout(self, packet_length): # noqa: N802 + """ + HACK: This patches the PortHandler behavior to set the correct packet timeouts. + + It fixes https://gitee.com/ftservo/SCServoSDK/issues/IBY2S6 + The bug is fixed on the official Feetech SDK repo (https://gitee.com/ftservo/FTServo_Python) + but because that version is not published on PyPI, we rely on the (unofficial) on that is, which needs + patching. + """ + self.packet_start_time = self.getCurrentTime() + self.packet_timeout = (self.tx_time_per_byte * packet_length) + (self.tx_time_per_byte * 3.0) + 50 + + +class FeetechMotorsBus(MotorsBus): + """ + The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on the + python feetech sdk to communicate with the motors, which is itself based on the dynamixel sdk. + """ + + apply_drive_mode = True + available_baudrates = deepcopy(SCAN_BAUDRATES) + default_baudrate = DEFAULT_BAUDRATE + default_timeout = DEFAULT_TIMEOUT_MS + model_baudrate_table = deepcopy(MODEL_BAUDRATE_TABLE) + model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) + model_encoding_table = deepcopy(MODEL_ENCODING_TABLE) + model_number_table = deepcopy(MODEL_NUMBER_TABLE) + model_resolution_table = deepcopy(MODEL_RESOLUTION) + normalized_data = deepcopy(NORMALIZED_DATA) + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + protocol_version: int = DEFAULT_PROTOCOL_VERSION, + ): + super().__init__(port, motors, calibration) + self.protocol_version = protocol_version + self._assert_same_protocol() + import scservo_sdk as scs + + self.port_handler = scs.PortHandler(self.port) + # HACK: monkeypatch + self.port_handler.setPacketTimeout = patch_setPacketTimeout.__get__( + self.port_handler, scs.PortHandler + ) + self.packet_handler = scs.PacketHandler(protocol_version) + self.sync_reader = scs.GroupSyncRead(self.port_handler, self.packet_handler, 0, 0) + self.sync_writer = scs.GroupSyncWrite(self.port_handler, self.packet_handler, 0, 0) + self._comm_success = scs.COMM_SUCCESS + self._no_error = 0x00 + + if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models): + raise ValueError(f"Some motors are incompatible with protocol_version={self.protocol_version}") + + def _assert_same_protocol(self) -> None: + if any(MODEL_PROTOCOL[model] != self.protocol_version for model in self.models): + raise RuntimeError("Some motors use an incompatible protocol.") + + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: + if instruction_name == "sync_read" and self.protocol_version == 1: + raise NotImplementedError( + "'Sync Read' is not available with Feetech motors using Protocol 1. Use 'Read' sequentially instead." + ) + if instruction_name == "broadcast_ping" and self.protocol_version == 1: + raise NotImplementedError( + "'Broadcast Ping' is not available with Feetech motors using Protocol 1. Use 'Ping' sequentially instead." + ) + + def _assert_same_firmware(self) -> None: + firmware_versions = self._read_firmware_version(self.ids, raise_on_error=True) + if len(set(firmware_versions.values())) != 1: + raise RuntimeError( + "Some Motors use different firmware versions:" + f"\n{pformat(firmware_versions)}\n" + "Update their firmware first using Feetech's software. " + "Visit https://www.feetechrc.com/software." + ) + + def _handshake(self) -> None: + self._assert_motors_exist() + self._assert_same_firmware() + + def _find_single_motor(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: + if self.protocol_version == 0: + return self._find_single_motor_p0(motor, initial_baudrate) + else: + return self._find_single_motor_p1(motor, initial_baudrate) + + def _find_single_motor_p0(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: + model = self.motors[motor].model + search_baudrates = ( + [initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model] + ) + expected_model_nb = self.model_number_table[model] + + for baudrate in search_baudrates: + self.set_baudrate(baudrate) + id_model = self.broadcast_ping() + if id_model: + found_id, found_model = next(iter(id_model.items())) + if found_model != expected_model_nb: + raise RuntimeError( + f"Found one motor on {baudrate=} with id={found_id} but it has a " + f"model number '{found_model}' different than the one expected: '{expected_model_nb}'. " + f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')." + ) + return baudrate, found_id + + raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") + + def _find_single_motor_p1(self, motor: str, initial_baudrate: int | None = None) -> tuple[int, int]: + import scservo_sdk as scs + + model = self.motors[motor].model + search_baudrates = ( + [initial_baudrate] if initial_baudrate is not None else self.model_baudrate_table[model] + ) + expected_model_nb = self.model_number_table[model] + + for baudrate in search_baudrates: + self.set_baudrate(baudrate) + for id_ in range(scs.MAX_ID + 1): + found_model = self.ping(id_) + if found_model is not None: + if found_model != expected_model_nb: + raise RuntimeError( + f"Found one motor on {baudrate=} with id={id_} but it has a " + f"model number '{found_model}' different than the one expected: '{expected_model_nb}'. " + f"Make sure you are connected only connected to the '{motor}' motor (model '{model}')." + ) + return baudrate, id_ + + raise RuntimeError(f"Motor '{motor}' (model '{model}') was not found. Make sure it is connected.") + + def configure_motors(self) -> None: + for motor in self.motors: + # By default, Feetech motors have a 500µs delay response time (corresponding to a value of 250 on + # the 'Return_Delay_Time' address). We ensure this is reduced to the minimum of 2µs (value of 0). + self.write("Return_Delay_Time", motor, 0) + # Set 'Maximum_Acceleration' to 254 to speedup acceleration and deceleration of the motors. + # Note: this address is not in the official STS3215 Memory Table + self.write("Maximum_Acceleration", motor, 254) + self.write("Acceleration", motor, 254) + + @property + def is_calibrated(self) -> bool: + motors_calibration = self.read_calibration() + if set(motors_calibration) != set(self.calibration): + return False + + same_ranges = all( + self.calibration[motor].range_min == cal.range_min + and self.calibration[motor].range_max == cal.range_max + for motor, cal in motors_calibration.items() + ) + if self.protocol_version == 1: + return same_ranges + + same_offsets = all( + self.calibration[motor].homing_offset == cal.homing_offset + for motor, cal in motors_calibration.items() + ) + return same_ranges and same_offsets + + def read_calibration(self) -> dict[str, MotorCalibration]: + offsets, mins, maxes = {}, {}, {} + for motor in self.motors: + mins[motor] = self.read("Min_Position_Limit", motor, normalize=False) + maxes[motor] = self.read("Max_Position_Limit", motor, normalize=False) + offsets[motor] = ( + self.read("Homing_Offset", motor, normalize=False) if self.protocol_version == 0 else 0 + ) + + calibration = {} + for motor, m in self.motors.items(): + calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=offsets[motor], + range_min=mins[motor], + range_max=maxes[motor], + ) + + return calibration + + def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None: + for motor, calibration in calibration_dict.items(): + if self.protocol_version == 0: + self.write("Homing_Offset", motor, calibration.homing_offset) + self.write("Min_Position_Limit", motor, calibration.range_min) + self.write("Max_Position_Limit", motor, calibration.range_max) + + self.calibration = calibration_dict + + def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: + """ + On Feetech Motors: + Present_Position = Actual_Position - Homing_Offset + """ + half_turn_homings = {} + for motor, pos in positions.items(): + model = self._get_motor_model(motor) + max_res = self.model_resolution_table[model] - 1 + half_turn_homings[motor] = pos - int(max_res / 2) + + return half_turn_homings + + def disable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + for motor in self._get_motors_list(motors): + self.write("Torque_Enable", motor, TorqueMode.DISABLED.value, num_retry=num_retry) + self.write("Lock", motor, 0, num_retry=num_retry) + + def _disable_torque(self, motor_id: int, model: str, num_retry: int = 0) -> None: + addr, length = get_address(self.model_ctrl_table, model, "Torque_Enable") + self._write(addr, length, motor_id, TorqueMode.DISABLED.value, num_retry=num_retry) + addr, length = get_address(self.model_ctrl_table, model, "Lock") + self._write(addr, length, motor_id, 0, num_retry=num_retry) + + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + for motor in self._get_motors_list(motors): + self.write("Torque_Enable", motor, TorqueMode.ENABLED.value, num_retry=num_retry) + self.write("Lock", motor, 1, num_retry=num_retry) + + def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + for id_ in ids_values: + model = self._id_to_model(id_) + encoding_table = self.model_encoding_table.get(model) + if encoding_table and data_name in encoding_table: + sign_bit = encoding_table[data_name] + ids_values[id_] = encode_sign_magnitude(ids_values[id_], sign_bit) + + return ids_values + + def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + for id_ in ids_values: + model = self._id_to_model(id_) + encoding_table = self.model_encoding_table.get(model) + if encoding_table and data_name in encoding_table: + sign_bit = encoding_table[data_name] + ids_values[id_] = decode_sign_magnitude(ids_values[id_], sign_bit) + + return ids_values + + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: + return _split_into_byte_chunks(value, length) + + def _broadcast_ping(self) -> tuple[dict[int, int], int]: + import scservo_sdk as scs + + data_list = {} + + status_length = 6 + + rx_length = 0 + wait_length = status_length * scs.MAX_ID + + txpacket = [0] * 6 + + tx_time_per_byte = (1000.0 / self.port_handler.getBaudRate()) * 10.0 + + txpacket[scs.PKT_ID] = scs.BROADCAST_ID + txpacket[scs.PKT_LENGTH] = 2 + txpacket[scs.PKT_INSTRUCTION] = scs.INST_PING + + result = self.packet_handler.txPacket(self.port_handler, txpacket) + if result != scs.COMM_SUCCESS: + self.port_handler.is_using = False + return data_list, result + + # set rx timeout + self.port_handler.setPacketTimeoutMillis((wait_length * tx_time_per_byte) + (3.0 * scs.MAX_ID) + 16.0) + + rxpacket = [] + while not self.port_handler.isPacketTimeout() and rx_length < wait_length: + rxpacket += self.port_handler.readPort(wait_length - rx_length) + rx_length = len(rxpacket) + + self.port_handler.is_using = False + + if rx_length == 0: + return data_list, scs.COMM_RX_TIMEOUT + + while True: + if rx_length < status_length: + return data_list, scs.COMM_RX_CORRUPT + + # find packet header + for idx in range(0, (rx_length - 1)): + if (rxpacket[idx] == 0xFF) and (rxpacket[idx + 1] == 0xFF): + break + + if idx == 0: # found at the beginning of the packet + # calculate checksum + checksum = 0 + for idx in range(2, status_length - 1): # except header & checksum + checksum += rxpacket[idx] + + checksum = ~checksum & 0xFF + if rxpacket[status_length - 1] == checksum: + result = scs.COMM_SUCCESS + data_list[rxpacket[scs.PKT_ID]] = rxpacket[scs.PKT_ERROR] + + del rxpacket[0:status_length] + rx_length = rx_length - status_length + + if rx_length == 0: + return data_list, result + else: + result = scs.COMM_RX_CORRUPT + # remove header (0xFF 0xFF) + del rxpacket[0:2] + rx_length = rx_length - 2 + else: + # remove unnecessary packets + del rxpacket[0:idx] + rx_length = rx_length - idx + + def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: + self._assert_protocol_is_compatible("broadcast_ping") + for n_try in range(1 + num_retry): + ids_status, comm = self._broadcast_ping() + if self._is_comm_success(comm): + break + logger.debug(f"Broadcast ping failed on port '{self.port}' ({n_try=})") + logger.debug(self.packet_handler.getTxRxResult(comm)) + + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + return + + ids_errors = {id_: status for id_, status in ids_status.items() if self._is_error(status)} + if ids_errors: + display_dict = {id_: self.packet_handler.getRxPacketError(err) for id_, err in ids_errors.items()} + logger.error(f"Some motors found returned an error status:\n{pformat(display_dict, indent=4)}") + + return self._read_model_number(list(ids_status), raise_on_error) + + def _read_firmware_version(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, str]: + firmware_versions = {} + for id_ in motor_ids: + firm_ver_major, comm, error = self._read( + *FIRMWARE_MAJOR_VERSION, id_, raise_on_error=raise_on_error + ) + if not self._is_comm_success(comm) or self._is_error(error): + continue + + firm_ver_minor, comm, error = self._read( + *FIRMWARE_MINOR_VERSION, id_, raise_on_error=raise_on_error + ) + if not self._is_comm_success(comm) or self._is_error(error): + continue + + firmware_versions[id_] = f"{firm_ver_major}.{firm_ver_minor}" + + return firmware_versions + + def _read_model_number(self, motor_ids: list[int], raise_on_error: bool = False) -> dict[int, int]: + model_numbers = {} + for id_ in motor_ids: + model_nb, comm, error = self._read(*MODEL_NUMBER, id_, raise_on_error=raise_on_error) + if not self._is_comm_success(comm) or self._is_error(error): + continue + + model_numbers[id_] = model_nb + + return model_numbers diff --git a/lerobot/common/motors/feetech/tables.py b/lerobot/common/motors/feetech/tables.py new file mode 100644 index 000000000..0a2f2659f --- /dev/null +++ b/lerobot/common/motors/feetech/tables.py @@ -0,0 +1,252 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FIRMWARE_MAJOR_VERSION = (0, 1) +FIRMWARE_MINOR_VERSION = (1, 1) +MODEL_NUMBER = (3, 2) + +# TODO(Steven): Consider doing the following: +# from enum import Enum +# class MyControlTableKey(Enum): +# ID = "ID" +# GOAL_SPEED = "Goal_Speed" +# ... +# +# MY_CONTROL_TABLE ={ +# MyControlTableKey.ID.value: (5,1) +# MyControlTableKey.GOAL_SPEED.value: (46, 2) +# ... +# } +# This allows me do to: +# bus.write(MyControlTableKey.GOAL_SPEED, ...) +# Instead of: +# bus.write("Goal_Speed", ...) +# This is important for two reasons: +# 1. The linter will tell me if I'm trying to use an invalid key, instead of me realizing when I get the RunTimeError +# 2. We can change the value of the MyControlTableKey enums without impacting the client code + +# data_name: (address, size_byte) +# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SMS-STS-emanual-229f4476422d4059abfb1cb0 +STS_SMS_SERIES_CONTROL_TABLE = { + # EPROM + "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only + "Model_Number": MODEL_NUMBER, # read-only + "ID": (5, 1), + "Baud_Rate": (6, 1), + "Return_Delay_Time": (7, 1), + "Response_Status_Level": (8, 1), + "Min_Position_Limit": (9, 2), + "Max_Position_Limit": (11, 2), + "Max_Temperature_Limit": (13, 1), + "Max_Voltage_Limit": (14, 1), + "Min_Voltage_Limit": (15, 1), + "Max_Torque_Limit": (16, 2), + "Phase": (18, 1), + "Unloading_Condition": (19, 1), + "LED_Alarm_Condition": (20, 1), + "P_Coefficient": (21, 1), + "D_Coefficient": (22, 1), + "I_Coefficient": (23, 1), + "Minimum_Startup_Force": (24, 2), + "CW_Dead_Zone": (26, 1), + "CCW_Dead_Zone": (27, 1), + "Protection_Current": (28, 2), + "Angular_Resolution": (30, 1), + "Homing_Offset": (31, 2), + "Operating_Mode": (33, 1), + "Protective_Torque": (34, 1), + "Protection_Time": (35, 1), + "Overload_Torque": (36, 1), + "Velocity_closed_loop_P_proportional_coefficient": (37, 1), + "Over_Current_Protection_Time": (38, 1), + "Velocity_closed_loop_I_integral_coefficient": (39, 1), + # SRAM + "Torque_Enable": (40, 1), + "Acceleration": (41, 1), + "Goal_Position": (42, 2), + "Goal_Time": (44, 2), + "Goal_Velocity": (46, 2), + "Torque_Limit": (48, 2), + "Lock": (55, 1), + "Present_Position": (56, 2), # read-only + "Present_Velocity": (58, 2), # read-only + "Present_Load": (60, 2), # read-only + "Present_Voltage": (62, 1), # read-only + "Present_Temperature": (63, 1), # read-only + "Status": (65, 1), # read-only + "Moving": (66, 1), # read-only + "Present_Current": (69, 2), # read-only + "Goal_Position_2": (71, 2), # read-only + # Factory + "Moving_Velocity": (80, 1), + "Moving_Velocity_Threshold": (80, 1), + "DTs": (81, 1), # (ms) + "Velocity_Unit_factor": (82, 1), + "Hts": (83, 1), # (ns) valid for firmware >= 2.54, other versions keep 0 + "Maximum_Velocity_Limit": (84, 1), + "Maximum_Acceleration": (85, 1), + "Acceleration_Multiplier ": (86, 1), # Acceleration multiplier in effect when acceleration is 0 +} + +# http://doc.feetech.cn/#/prodinfodownload?srcType=FT-SCSCL-emanual-cbcc8ab2e3384282a01d4bf3 +SCS_SERIES_CONTROL_TABLE = { + # EPROM + "Firmware_Major_Version": FIRMWARE_MAJOR_VERSION, # read-only + "Firmware_Minor_Version": FIRMWARE_MINOR_VERSION, # read-only + "Model_Number": MODEL_NUMBER, # read-only + "ID": (5, 1), + "Baud_Rate": (6, 1), + "Return_Delay_Time": (7, 1), + "Response_Status_Level": (8, 1), + "Min_Position_Limit": (9, 2), + "Max_Position_Limit": (11, 2), + "Max_Temperature_Limit": (13, 1), + "Max_Voltage_Limit": (14, 1), + "Min_Voltage_Limit": (15, 1), + "Max_Torque_Limit": (16, 2), + "Phase": (18, 1), + "Unloading_Condition": (19, 1), + "LED_Alarm_Condition": (20, 1), + "P_Coefficient": (21, 1), + "D_Coefficient": (22, 1), + "I_Coefficient": (23, 1), + "Minimum_Startup_Force": (24, 2), + "CW_Dead_Zone": (26, 1), + "CCW_Dead_Zone": (27, 1), + "Protective_Torque": (37, 1), + "Protection_Time": (38, 1), + # SRAM + "Torque_Enable": (40, 1), + "Acceleration": (41, 1), + "Goal_Position": (42, 2), + "Running_Time": (44, 2), + "Goal_Velocity": (46, 2), + "Lock": (48, 1), + "Present_Position": (56, 2), # read-only + "Present_Velocity": (58, 2), # read-only + "Present_Load": (60, 2), # read-only + "Present_Voltage": (62, 1), # read-only + "Present_Temperature": (63, 1), # read-only + "Sync_Write_Flag": (64, 1), # read-only + "Status": (65, 1), # read-only + "Moving": (66, 1), # read-only + # Factory + "PWM_Maximum_Step": (78, 1), + "Moving_Velocity_Threshold*50": (79, 1), + "DTs": (80, 1), # (ms) + "Minimum_Velocity_Limit*50": (81, 1), + "Maximum_Velocity_Limit*50": (82, 1), + "Acceleration_2": (83, 1), # don't know what that is +} + +STS_SMS_SERIES_BAUDRATE_TABLE = { + 1_000_000: 0, + 500_000: 1, + 250_000: 2, + 128_000: 3, + 115_200: 4, + 57_600: 5, + 38_400: 6, + 19_200: 7, +} + +SCS_SERIES_BAUDRATE_TABLE = { + 1_000_000: 0, + 500_000: 1, + 250_000: 2, + 128_000: 3, + 115_200: 4, + 57_600: 5, + 38_400: 6, + 19_200: 7, +} + +MODEL_CONTROL_TABLE = { + "sts_series": STS_SMS_SERIES_CONTROL_TABLE, + "scs_series": SCS_SERIES_CONTROL_TABLE, + "sms_series": STS_SMS_SERIES_CONTROL_TABLE, + "sts3215": STS_SMS_SERIES_CONTROL_TABLE, + "sts3250": STS_SMS_SERIES_CONTROL_TABLE, + "scs0009": SCS_SERIES_CONTROL_TABLE, + "sm8512bl": STS_SMS_SERIES_CONTROL_TABLE, +} + +MODEL_RESOLUTION = { + "sts_series": 4096, + "sms_series": 4096, + "scs_series": 1024, + "sts3215": 4096, + "sts3250": 4096, + "sm8512bl": 65536, + "scs0009": 1024, +} + +MODEL_BAUDRATE_TABLE = { + "sts_series": STS_SMS_SERIES_BAUDRATE_TABLE, + "sms_series": STS_SMS_SERIES_BAUDRATE_TABLE, + "scs_series": SCS_SERIES_BAUDRATE_TABLE, + "sm8512bl": STS_SMS_SERIES_BAUDRATE_TABLE, + "sts3215": STS_SMS_SERIES_BAUDRATE_TABLE, + "sts3250": STS_SMS_SERIES_BAUDRATE_TABLE, + "scs0009": SCS_SERIES_BAUDRATE_TABLE, +} + +# Sign-Magnitude encoding bits +STS_SMS_SERIES_ENCODINGS_TABLE = { + "Homing_Offset": 11, + "Goal_Velocity": 15, + "Present_Velocity": 15, +} + +MODEL_ENCODING_TABLE = { + "sts_series": STS_SMS_SERIES_ENCODINGS_TABLE, + "sms_series": STS_SMS_SERIES_ENCODINGS_TABLE, + "scs_series": {}, + "sts3215": STS_SMS_SERIES_ENCODINGS_TABLE, + "sts3250": STS_SMS_SERIES_ENCODINGS_TABLE, + "sm8512bl": STS_SMS_SERIES_ENCODINGS_TABLE, + "scs0009": {}, +} + +SCAN_BAUDRATES = [ + 4_800, + 9_600, + 14_400, + 19_200, + 38_400, + 57_600, + 115_200, + 128_000, + 250_000, + 500_000, + 1_000_000, +] + +MODEL_NUMBER_TABLE = { + "sts3215": 777, + "sts3250": 2825, + "sm8512bl": 11272, + "scs0009": 1284, +} + +MODEL_PROTOCOL = { + "sts_series": 0, + "sms_series": 0, + "scs_series": 1, + "sts3215": 0, + "sts3250": 0, + "sm8512bl": 0, + "scs0009": 1, +} diff --git a/lerobot/common/motors/motors_bus.py b/lerobot/common/motors/motors_bus.py new file mode 100644 index 000000000..7ac9e6813 --- /dev/null +++ b/lerobot/common/motors/motors_bus.py @@ -0,0 +1,1219 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa: N802 +# This noqa is for the Protocols classes: PortHandler, PacketHandler GroupSyncRead/Write +# TODO(aliberts): Add block noqa when feature below is available +# https://github.com/astral-sh/ruff/issues/3711 + +import abc +import logging +from contextlib import contextmanager +from dataclasses import dataclass +from enum import Enum +from functools import cached_property +from pprint import pformat +from typing import Protocol, TypeAlias + +import serial +from deepdiff import DeepDiff +from tqdm import tqdm + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.utils.utils import enter_pressed, move_cursor_up + +NameOrID: TypeAlias = str | int +Value: TypeAlias = int | float + +logger = logging.getLogger(__name__) + + +def get_ctrl_table(model_ctrl_table: dict[str, dict], model: str) -> dict[str, tuple[int, int]]: + ctrl_table = model_ctrl_table.get(model) + if ctrl_table is None: + raise KeyError(f"Control table for {model=} not found.") + return ctrl_table + + +def get_address(model_ctrl_table: dict[str, dict], model: str, data_name: str) -> tuple[int, int]: + ctrl_table = get_ctrl_table(model_ctrl_table, model) + addr_bytes = ctrl_table.get(data_name) + if addr_bytes is None: + raise KeyError(f"Address for '{data_name}' not found in {model} control table.") + return addr_bytes + + +def assert_same_address(model_ctrl_table: dict[str, dict], motor_models: list[str], data_name: str) -> None: + all_addr = [] + all_bytes = [] + for model in motor_models: + addr, bytes = get_address(model_ctrl_table, model, data_name) + all_addr.append(addr) + all_bytes.append(bytes) + + if len(set(all_addr)) != 1: + raise NotImplementedError( + f"At least two motor models use a different address for `data_name`='{data_name}'" + f"({list(zip(motor_models, all_addr, strict=False))})." + ) + + if len(set(all_bytes)) != 1: + raise NotImplementedError( + f"At least two motor models use a different bytes representation for `data_name`='{data_name}'" + f"({list(zip(motor_models, all_bytes, strict=False))})." + ) + + +class MotorNormMode(str, Enum): + RANGE_0_100 = "range_0_100" + RANGE_M100_100 = "range_m100_100" + DEGREES = "degrees" + + +@dataclass +class MotorCalibration: + id: int + drive_mode: int + homing_offset: int + range_min: int + range_max: int + + +@dataclass +class Motor: + id: int + model: str + norm_mode: MotorNormMode + + +class JointOutOfRangeError(Exception): + def __init__(self, message="Joint is out of range"): + self.message = message + super().__init__(self.message) + + +class PortHandler(Protocol): + def __init__(self, port_name): + self.is_open: bool + self.baudrate: int + self.packet_start_time: float + self.packet_timeout: float + self.tx_time_per_byte: float + self.is_using: bool + self.port_name: str + self.ser: serial.Serial + + def openPort(self): ... + def closePort(self): ... + def clearPort(self): ... + def setPortName(self, port_name): ... + def getPortName(self): ... + def setBaudRate(self, baudrate): ... + def getBaudRate(self): ... + def getBytesAvailable(self): ... + def readPort(self, length): ... + def writePort(self, packet): ... + def setPacketTimeout(self, packet_length): ... + def setPacketTimeoutMillis(self, msec): ... + def isPacketTimeout(self): ... + def getCurrentTime(self): ... + def getTimeSinceStart(self): ... + def setupPort(self, cflag_baud): ... + def getCFlagBaud(self, baudrate): ... + + +class PacketHandler(Protocol): + def getTxRxResult(self, result): ... + def getRxPacketError(self, error): ... + def txPacket(self, port, txpacket): ... + def rxPacket(self, port): ... + def txRxPacket(self, port, txpacket): ... + def ping(self, port, id): ... + def action(self, port, id): ... + def readTx(self, port, id, address, length): ... + def readRx(self, port, id, length): ... + def readTxRx(self, port, id, address, length): ... + def read1ByteTx(self, port, id, address): ... + def read1ByteRx(self, port, id): ... + def read1ByteTxRx(self, port, id, address): ... + def read2ByteTx(self, port, id, address): ... + def read2ByteRx(self, port, id): ... + def read2ByteTxRx(self, port, id, address): ... + def read4ByteTx(self, port, id, address): ... + def read4ByteRx(self, port, id): ... + def read4ByteTxRx(self, port, id, address): ... + def writeTxOnly(self, port, id, address, length, data): ... + def writeTxRx(self, port, id, address, length, data): ... + def write1ByteTxOnly(self, port, id, address, data): ... + def write1ByteTxRx(self, port, id, address, data): ... + def write2ByteTxOnly(self, port, id, address, data): ... + def write2ByteTxRx(self, port, id, address, data): ... + def write4ByteTxOnly(self, port, id, address, data): ... + def write4ByteTxRx(self, port, id, address, data): ... + def regWriteTxOnly(self, port, id, address, length, data): ... + def regWriteTxRx(self, port, id, address, length, data): ... + def syncReadTx(self, port, start_address, data_length, param, param_length): ... + def syncWriteTxOnly(self, port, start_address, data_length, param, param_length): ... + + +class GroupSyncRead(Protocol): + def __init__(self, port, ph, start_address, data_length): + self.port: str + self.ph: PortHandler + self.start_address: int + self.data_length: int + self.last_result: bool + self.is_param_changed: bool + self.param: list + self.data_dict: dict + + def makeParam(self): ... + def addParam(self, id): ... + def removeParam(self, id): ... + def clearParam(self): ... + def txPacket(self): ... + def rxPacket(self): ... + def txRxPacket(self): ... + def isAvailable(self, id, address, data_length): ... + def getData(self, id, address, data_length): ... + + +class GroupSyncWrite(Protocol): + def __init__(self, port, ph, start_address, data_length): + self.port: str + self.ph: PortHandler + self.start_address: int + self.data_length: int + self.is_param_changed: bool + self.param: list + self.data_dict: dict + + def makeParam(self): ... + def addParam(self, id, data): ... + def removeParam(self, id): ... + def changeParam(self, id, data): ... + def clearParam(self): ... + def txPacket(self): ... + + +class MotorsBus(abc.ABC): + """ + A MotorsBus allows to efficiently read and write to the attached motors. + It represents several motors daisy-chained together and connected through a serial port. + There are currently two implementations of this abstract class: + - DynamixelMotorsBus + - FeetechMotorsBus + + Note: This class may evolve in the future should we add support for other types of bus. + + A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). + To find the port, you can run our utility script: + ```bash + python -m lerobot.find_port.py + >>> Finding all available ports for the MotorsBus. + >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] + >>> Remove the usb cable from your MotorsBus and press Enter when done. + >>> The port of this MotorsBus is /dev/tty.usbmodem575E0031751. + >>> Reconnect the usb cable. + ``` + + Example of usage for 1 Feetech sts3215 motor connected to the bus: + ```python + bus = FeetechMotorsBus( + port="/dev/tty.usbmodem575E0031751", + motors={"my_motor": (1, "sts3215")}, + ) + bus.connect() + + position = bus.read("Present_Position", "my_motor", normalize=False) + + # Move from a few motor steps as an example + few_steps = 30 + bus.write("Goal_Position", "my_motor", position + few_steps, normalize=False) + + # When done, properly disconnect the port using + bus.disconnect() + ``` + """ + + apply_drive_mode: bool + available_baudrates: list[int] + default_baudrate: int + default_timeout: int + model_baudrate_table: dict[str, dict] + model_ctrl_table: dict[str, dict] + model_encoding_table: dict[str, dict] + model_number_table: dict[str, int] + model_resolution_table: dict[str, int] + normalized_data: list[str] + + def __init__( + self, + port: str, + motors: dict[str, Motor], + calibration: dict[str, MotorCalibration] | None = None, + ): + self.port = port + self.motors = motors + self.calibration = calibration if calibration else {} + + self.port_handler: PortHandler + self.packet_handler: PacketHandler + self.sync_reader: GroupSyncRead + self.sync_writer: GroupSyncWrite + self._comm_success: int + self._no_error: int + + self._id_to_model_dict = {m.id: m.model for m in self.motors.values()} + self._id_to_name_dict = {m.id: motor for motor, m in self.motors.items()} + self._model_nb_to_model_dict = {v: k for k, v in self.model_number_table.items()} + + self._validate_motors() + + def __len__(self): + return len(self.motors) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(\n" + f" Port: '{self.port}',\n" + f" Motors: \n{pformat(self.motors, indent=8, sort_dicts=False)},\n" + ")',\n" + ) + + @cached_property + def _has_different_ctrl_tables(self) -> bool: + if len(self.models) < 2: + return False + + first_table = self.model_ctrl_table[self.models[0]] + return any( + DeepDiff(first_table, get_ctrl_table(self.model_ctrl_table, model)) for model in self.models[1:] + ) + + @cached_property + def models(self) -> list[str]: + return [m.model for m in self.motors.values()] + + @cached_property + def ids(self) -> list[int]: + return [m.id for m in self.motors.values()] + + def _model_nb_to_model(self, motor_nb: int) -> str: + return self._model_nb_to_model_dict[motor_nb] + + def _id_to_model(self, motor_id: int) -> str: + return self._id_to_model_dict[motor_id] + + def _id_to_name(self, motor_id: int) -> str: + return self._id_to_name_dict[motor_id] + + def _get_motor_id(self, motor: NameOrID) -> int: + if isinstance(motor, str): + return self.motors[motor].id + elif isinstance(motor, int): + return motor + else: + raise TypeError(f"'{motor}' should be int, str.") + + def _get_motor_model(self, motor: NameOrID) -> int: + if isinstance(motor, str): + return self.motors[motor].model + elif isinstance(motor, int): + return self._id_to_model_dict[motor] + else: + raise TypeError(f"'{motor}' should be int, str.") + + def _get_motors_list(self, motors: str | list[str] | None) -> list[str]: + if motors is None: + return list(self.motors) + elif isinstance(motors, str): + return [motors] + elif isinstance(motors, list): + return motors.copy() + else: + raise TypeError(motors) + + def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]: + if isinstance(values, (int, float)): + return dict.fromkeys(self.ids, values) + elif isinstance(values, dict): + return {self.motors[motor].id: val for motor, val in values.items()} + else: + raise TypeError(f"'values' is expected to be a single value or a dict. Got {values}") + + def _validate_motors(self) -> None: + if len(self.ids) != len(set(self.ids)): + raise ValueError(f"Some motors have the same id!\n{self}") + + # Ensure ctrl table available for all models + for model in self.models: + get_ctrl_table(self.model_ctrl_table, model) + + def _is_comm_success(self, comm: int) -> bool: + return comm == self._comm_success + + def _is_error(self, error: int) -> bool: + return error != self._no_error + + def _assert_motors_exist(self) -> None: + expected_models = {m.id: self.model_number_table[m.model] for m in self.motors.values()} + + found_models = {} + for id_ in self.ids: + model_nb = self.ping(id_) + if model_nb is not None: + found_models[id_] = model_nb + + missing_ids = [id_ for id_ in self.ids if id_ not in found_models] + wrong_models = { + id_: (expected_models[id_], found_models[id_]) + for id_ in found_models + if expected_models.get(id_) != found_models[id_] + } + + if missing_ids or wrong_models: + error_lines = [f"{self.__class__.__name__} motor check failed on port '{self.port}':"] + + if missing_ids: + error_lines.append("\nMissing motor IDs:") + error_lines.extend( + f" - {id_} (expected model: {expected_models[id_]})" for id_ in missing_ids + ) + + if wrong_models: + error_lines.append("\nMotors with incorrect model numbers:") + error_lines.extend( + f" - {id_} ({self._id_to_name(id_)}): expected {expected}, found {found}" + for id_, (expected, found) in wrong_models.items() + ) + + error_lines.append("\nFull expected motor list (id: model_number):") + error_lines.append(pformat(expected_models, indent=4, sort_dicts=False)) + error_lines.append("\nFull found motor list (id: model_number):") + error_lines.append(pformat(found_models, indent=4, sort_dicts=False)) + + raise RuntimeError("\n".join(error_lines)) + + @abc.abstractmethod + def _assert_protocol_is_compatible(self, instruction_name: str) -> None: + pass + + @property + def is_connected(self) -> bool: + """bool: `True` if the underlying serial port is open.""" + return self.port_handler.is_open + + def connect(self, handshake: bool = True) -> None: + """Open the serial port and initialise communication. + + Args: + handshake (bool, optional): Pings every expected motor and performs additional + integrity checks specific to the implementation. Defaults to `True`. + + Raises: + DeviceAlreadyConnectedError: The port is already open. + ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError( + f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice." + ) + + self._connect(handshake) + self.set_timeout() + logger.debug(f"{self.__class__.__name__} connected.") + + def _connect(self, handshake: bool = True) -> None: + try: + if not self.port_handler.openPort(): + raise OSError(f"Failed to open port '{self.port}'.") + elif handshake: + self._handshake() + except (FileNotFoundError, OSError, serial.SerialException) as e: + raise ConnectionError( + f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port." + "\nTry running `python lerobot/find_port.py`\n" + ) from e + + @abc.abstractmethod + def _handshake(self) -> None: + pass + + def disconnect(self, disable_torque: bool = True) -> None: + """Close the serial port (optionally disabling torque first). + + Args: + disable_torque (bool, optional): If `True` (default) torque is disabled on every motor before + closing the port. This can prevent damaging motors if they are left applying resisting torque + after disconnect. + """ + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first." + ) + + if disable_torque: + self.port_handler.clearPort() + self.port_handler.is_using = False + self.disable_torque(num_retry=5) + + self.port_handler.closePort() + logger.debug(f"{self.__class__.__name__} disconnected.") + + @classmethod + def scan_port(cls, port: str, *args, **kwargs) -> dict[int, list[int]]: + """Probe *port* at every supported baud-rate and list responding IDs. + + Args: + port (str): Serial/USB port to scan (e.g. ``"/dev/ttyUSB0"``). + *args, **kwargs: Forwarded to the subclass constructor. + + Returns: + dict[int, list[int]]: Mapping *baud-rate → list of motor IDs* + for every baud-rate that produced at least one response. + """ + bus = cls(port, {}, *args, **kwargs) + bus._connect(handshake=False) + baudrate_ids = {} + for baudrate in tqdm(bus.available_baudrates, desc="Scanning port"): + bus.set_baudrate(baudrate) + ids_models = bus.broadcast_ping() + if ids_models: + tqdm.write(f"Motors found for {baudrate=}: {pformat(ids_models, indent=4)}") + baudrate_ids[baudrate] = list(ids_models) + + bus.port_handler.closePort() + return baudrate_ids + + def setup_motor( + self, motor: str, initial_baudrate: int | None = None, initial_id: int | None = None + ) -> None: + """Assign the correct ID and baud-rate to a single motor. + + This helper temporarily switches to the motor's current settings, disables torque, sets the desired + ID, and finally programs the bus' default baud-rate. + + Args: + motor (str): Key of the motor in :pyattr:`motors`. + initial_baudrate (int | None, optional): Current baud-rate (skips scanning when provided). + Defaults to None. + initial_id (int | None, optional): Current ID (skips scanning when provided). Defaults to None. + + Raises: + RuntimeError: The motor could not be found or its model number + does not match the expected one. + ConnectionError: Communication with the motor failed. + """ + if not self.is_connected: + self._connect(handshake=False) + + if initial_baudrate is None: + initial_baudrate, initial_id = self._find_single_motor(motor) + + if initial_id is None: + _, initial_id = self._find_single_motor(motor, initial_baudrate) + + model = self.motors[motor].model + target_id = self.motors[motor].id + self.set_baudrate(initial_baudrate) + self._disable_torque(initial_id, model) + + # Set ID + addr, length = get_address(self.model_ctrl_table, model, "ID") + self._write(addr, length, initial_id, target_id) + + # Set Baudrate + addr, length = get_address(self.model_ctrl_table, model, "Baud_Rate") + baudrate_value = self.model_baudrate_table[model][self.default_baudrate] + self._write(addr, length, target_id, baudrate_value) + + self.set_baudrate(self.default_baudrate) + + @abc.abstractmethod + def _find_single_motor(self, motor: str, initial_baudrate: int | None) -> tuple[int, int]: + pass + + @abc.abstractmethod + def configure_motors(self) -> None: + """Write implementation-specific recommended settings to every motor. + + Typical changes include shortening the return delay, increasing + acceleration limits or disabling safety locks. + """ + pass + + @abc.abstractmethod + def disable_torque(self, motors: int | str | list[str] | None = None, num_retry: int = 0) -> None: + """Disable torque on selected motors. + + Disabling Torque allows to write to the motors' permanent memory area (EPROM/EEPROM). + + Args: + motors (int | str | list[str] | None, optional): Target motors. Accepts a motor name, an ID, a + list of names or `None` to affect every registered motor. Defaults to `None`. + num_retry (int, optional): Number of additional retry attempts on communication failure. + Defaults to 0. + """ + pass + + @abc.abstractmethod + def _disable_torque(self, motor: int, model: str, num_retry: int = 0) -> None: + pass + + @abc.abstractmethod + def enable_torque(self, motors: str | list[str] | None = None, num_retry: int = 0) -> None: + """Enable torque on selected motors. + + Args: + motor (int): Same semantics as :pymeth:`disable_torque`. Defaults to `None`. + num_retry (int, optional): Number of additional retry attempts on communication failure. + Defaults to 0. + """ + pass + + @contextmanager + def torque_disabled(self): + """Context-manager that guarantees torque is re-enabled. + + This helper is useful to temporarily disable torque when configuring motors. + + Examples: + >>> with bus.torque_disabled(): + ... # Safe operations here + ... pass + """ + self.disable_torque() + try: + yield + finally: + self.enable_torque() + + def set_timeout(self, timeout_ms: int | None = None): + """Change the packet timeout used by the SDK. + + Args: + timeout_ms (int | None, optional): Timeout in *milliseconds*. If `None` (default) the method falls + back to :pyattr:`default_timeout`. + """ + timeout_ms = timeout_ms if timeout_ms is not None else self.default_timeout + self.port_handler.setPacketTimeoutMillis(timeout_ms) + + def get_baudrate(self) -> int: + """Return the current baud-rate configured on the port. + + Returns: + int: Baud-rate in bits / second. + """ + return self.port_handler.getBaudRate() + + def set_baudrate(self, baudrate: int) -> None: + """Set a new UART baud-rate on the port. + + Args: + baudrate (int): Desired baud-rate in bits / second. + + Raises: + RuntimeError: The SDK failed to apply the change. + """ + present_bus_baudrate = self.port_handler.getBaudRate() + if present_bus_baudrate != baudrate: + logger.info(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") + self.port_handler.setBaudRate(baudrate) + + if self.port_handler.getBaudRate() != baudrate: + raise RuntimeError("Failed to write bus baud rate.") + + @property + @abc.abstractmethod + def is_calibrated(self) -> bool: + """bool: ``True`` if the cached calibration matches the motors.""" + pass + + @abc.abstractmethod + def read_calibration(self) -> dict[str, MotorCalibration]: + """Read calibration parameters from the motors. + + Returns: + dict[str, MotorCalibration]: Mapping *motor name → calibration*. + """ + pass + + @abc.abstractmethod + def write_calibration(self, calibration_dict: dict[str, MotorCalibration]) -> None: + """Write calibration parameters to the motors and cache them. + + Args: + calibration_dict (dict[str, MotorCalibration]): Calibration obtained from + :pymeth:`read_calibration` or crafted by the user. + """ + pass + + def reset_calibration(self, motors: NameOrID | list[NameOrID] | None = None) -> None: + """Restore factory calibration for the selected motors. + + Homing offset is set to ``0`` and min/max position limits are set to the full usable range. + The in-memory :pyattr:`calibration` is cleared. + + Args: + motors (NameOrID | list[NameOrID] | None, optional): Selection of motors. `None` (default) + resets every motor. + """ + if motors is None: + motors = list(self.motors) + elif isinstance(motors, (str, int)): + motors = [motors] + elif not isinstance(motors, list): + raise TypeError(motors) + + for motor in motors: + model = self._get_motor_model(motor) + max_res = self.model_resolution_table[model] - 1 + self.write("Homing_Offset", motor, 0, normalize=False) + self.write("Min_Position_Limit", motor, 0, normalize=False) + self.write("Max_Position_Limit", motor, max_res, normalize=False) + + self.calibration = {} + + def set_half_turn_homings(self, motors: NameOrID | list[NameOrID] | None = None) -> dict[NameOrID, Value]: + """Centre each motor range around its current position. + + The function computes and writes a homing offset such that the present position becomes exactly one + half-turn (e.g. `2047` on a 12-bit encoder). + + Args: + motors (NameOrID | list[NameOrID] | None, optional): Motors to adjust. Defaults to all motors (`None`). + + Returns: + dict[NameOrID, Value]: Mapping *motor → written homing offset*. + """ + if motors is None: + motors = list(self.motors) + elif isinstance(motors, (str, int)): + motors = [motors] + elif not isinstance(motors, list): + raise TypeError(motors) + + self.reset_calibration(motors) + actual_positions = self.sync_read("Present_Position", motors, normalize=False) + homing_offsets = self._get_half_turn_homings(actual_positions) + for motor, offset in homing_offsets.items(): + self.write("Homing_Offset", motor, offset) + + return homing_offsets + + @abc.abstractmethod + def _get_half_turn_homings(self, positions: dict[NameOrID, Value]) -> dict[NameOrID, Value]: + pass + + def record_ranges_of_motion( + self, motors: NameOrID | list[NameOrID] | None = None, display_values: bool = True + ) -> tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: + """Interactively record the min/max encoder values of each motor. + + Move the joints by hand (with torque disabled) while the method streams live positions. Press + :kbd:`Enter` to finish. + + Args: + motors (NameOrID | list[NameOrID] | None, optional): Motors to record. + Defaults to every motor (`None`). + display_values (bool, optional): When `True` (default) a live table is printed to the console. + + Returns: + tuple[dict[NameOrID, Value], dict[NameOrID, Value]]: Two dictionaries *mins* and *maxes* with the + extreme values observed for each motor. + """ + if motors is None: + motors = list(self.motors) + elif isinstance(motors, (str, int)): + motors = [motors] + elif not isinstance(motors, list): + raise TypeError(motors) + + start_positions = self.sync_read("Present_Position", motors, normalize=False) + mins = start_positions.copy() + maxes = start_positions.copy() + + user_pressed_enter = False + while not user_pressed_enter: + positions = self.sync_read("Present_Position", motors, normalize=False) + mins = {motor: min(positions[motor], min_) for motor, min_ in mins.items()} + maxes = {motor: max(positions[motor], max_) for motor, max_ in maxes.items()} + + if display_values: + print("\n-------------------------------------------") + print(f"{'NAME':<15} | {'MIN':>6} | {'POS':>6} | {'MAX':>6}") + for motor in motors: + print(f"{motor:<15} | {mins[motor]:>6} | {positions[motor]:>6} | {maxes[motor]:>6}") + + if enter_pressed(): + user_pressed_enter = True + + if display_values and not user_pressed_enter: + # Move cursor up to overwrite the previous output + move_cursor_up(len(motors) + 3) + + same_min_max = [motor for motor in motors if mins[motor] == maxes[motor]] + if same_min_max: + raise ValueError(f"Some motors have the same min and max values:\n{pformat(same_min_max)}") + + return mins, maxes + + def _normalize(self, ids_values: dict[int, int]) -> dict[int, float]: + if not self.calibration: + raise RuntimeError(f"{self} has no calibration registered.") + + normalized_values = {} + for id_, val in ids_values.items(): + motor = self._id_to_name(id_) + min_ = self.calibration[motor].range_min + max_ = self.calibration[motor].range_max + drive_mode = self.apply_drive_mode and self.calibration[motor].drive_mode + if max_ == min_: + raise ValueError(f"Invalid calibration for motor '{motor}': min and max are equal.") + + bounded_val = min(max_, max(min_, val)) + if self.motors[motor].norm_mode is MotorNormMode.RANGE_M100_100: + norm = (((bounded_val - min_) / (max_ - min_)) * 200) - 100 + normalized_values[id_] = -norm if drive_mode else norm + elif self.motors[motor].norm_mode is MotorNormMode.RANGE_0_100: + norm = ((bounded_val - min_) / (max_ - min_)) * 100 + normalized_values[id_] = 100 - norm if drive_mode else norm + elif self.motors[motor].norm_mode is MotorNormMode.DEGREES: + mid = (min_ + max_) / 2 + max_res = self.model_resolution_table[self._id_to_model(id_)] - 1 + normalized_values[id_] = (val - mid) * 360 / max_res + else: + raise NotImplementedError + + return normalized_values + + def _unnormalize(self, ids_values: dict[int, float]) -> dict[int, int]: + if not self.calibration: + raise RuntimeError(f"{self} has no calibration registered.") + + unnormalized_values = {} + for id_, val in ids_values.items(): + motor = self._id_to_name(id_) + min_ = self.calibration[motor].range_min + max_ = self.calibration[motor].range_max + drive_mode = self.apply_drive_mode and self.calibration[motor].drive_mode + if max_ == min_: + raise ValueError(f"Invalid calibration for motor '{motor}': min and max are equal.") + + if self.motors[motor].norm_mode is MotorNormMode.RANGE_M100_100: + val = -val if drive_mode else val + bounded_val = min(100.0, max(-100.0, val)) + unnormalized_values[id_] = int(((bounded_val + 100) / 200) * (max_ - min_) + min_) + elif self.motors[motor].norm_mode is MotorNormMode.RANGE_0_100: + val = 100 - val if drive_mode else val + bounded_val = min(100.0, max(0.0, val)) + unnormalized_values[id_] = int((bounded_val / 100) * (max_ - min_) + min_) + elif self.motors[motor].norm_mode is MotorNormMode.DEGREES: + mid = (min_ + max_) / 2 + max_res = self.model_resolution_table[self._id_to_model(id_)] - 1 + unnormalized_values[id_] = int((val * max_res / 360) + mid) + else: + raise NotImplementedError + + return unnormalized_values + + @abc.abstractmethod + def _encode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + pass + + @abc.abstractmethod + def _decode_sign(self, data_name: str, ids_values: dict[int, int]) -> dict[int, int]: + pass + + def _serialize_data(self, value: int, length: int) -> list[int]: + """ + Converts an unsigned integer value into a list of byte-sized integers to be sent via a communication + protocol. Depending on the protocol, split values can be in big-endian or little-endian order. + + Supported data length for both Feetech and Dynamixel: + - 1 (for values 0 to 255) + - 2 (for values 0 to 65,535) + - 4 (for values 0 to 4,294,967,295) + """ + if value < 0: + raise ValueError(f"Negative values are not allowed: {value}") + + max_value = {1: 0xFF, 2: 0xFFFF, 4: 0xFFFFFFFF}.get(length) + if max_value is None: + raise NotImplementedError(f"Unsupported byte size: {length}. Expected [1, 2, 4].") + + if value > max_value: + raise ValueError(f"Value {value} exceeds the maximum for {length} bytes ({max_value}).") + + return self._split_into_byte_chunks(value, length) + + @abc.abstractmethod + def _split_into_byte_chunks(self, value: int, length: int) -> list[int]: + """Convert an integer into a list of byte-sized integers.""" + pass + + def ping(self, motor: NameOrID, num_retry: int = 0, raise_on_error: bool = False) -> int | None: + """Ping a single motor and return its model number. + + Args: + motor (NameOrID): Target motor (name or ID). + num_retry (int, optional): Extra attempts before giving up. Defaults to `0`. + raise_on_error (bool, optional): If `True` communication errors raise exceptions instead of + returning `None`. Defaults to `False`. + + Returns: + int | None: Motor model number or `None` on failure. + """ + id_ = self._get_motor_id(motor) + for n_try in range(1 + num_retry): + model_number, comm, error = self.packet_handler.ping(self.port_handler, id_) + if self._is_comm_success(comm): + break + logger.debug(f"ping failed for {id_=}: {n_try=} got {comm=} {error=}") + + if not self._is_comm_success(comm): + if raise_on_error: + raise ConnectionError(self.packet_handler.getTxRxResult(comm)) + else: + return + if self._is_error(error): + if raise_on_error: + raise RuntimeError(self.packet_handler.getRxPacketError(error)) + else: + return + + return model_number + + @abc.abstractmethod + def broadcast_ping(self, num_retry: int = 0, raise_on_error: bool = False) -> dict[int, int] | None: + """Ping every ID on the bus using the broadcast address. + + Args: + num_retry (int, optional): Retry attempts. Defaults to `0`. + raise_on_error (bool, optional): When `True` failures raise an exception instead of returning + `None`. Defaults to `False`. + + Returns: + dict[int, int] | None: Mapping *id → model number* or `None` if the call failed. + """ + pass + + def read( + self, + data_name: str, + motor: str, + *, + normalize: bool = True, + num_retry: int = 0, + ) -> Value: + """Read a register from a motor. + + Args: + data_name (str): Control-table key (e.g. `"Present_Position"`). + motor (str): Motor name. + normalize (bool, optional): When `True` (default) scale the value to a user-friendly range as + defined by the calibration. + num_retry (int, optional): Retry attempts. Defaults to `0`. + + Returns: + Value: Raw or normalised value depending on *normalize*. + """ + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + id_ = self.motors[motor].id + model = self.motors[motor].model + addr, length = get_address(self.model_ctrl_table, model, data_name) + + err_msg = f"Failed to read '{data_name}' on {id_=} after {num_retry + 1} tries." + value, _, _ = self._read(addr, length, id_, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + + id_value = self._decode_sign(data_name, {id_: value}) + + if normalize and data_name in self.normalized_data: + id_value = self._normalize(id_value) + + return id_value[id_] + + def _read( + self, + address: int, + length: int, + motor_id: int, + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> tuple[int, int]: + if length == 1: + read_fn = self.packet_handler.read1ByteTxRx + elif length == 2: + read_fn = self.packet_handler.read2ByteTxRx + elif length == 4: + read_fn = self.packet_handler.read4ByteTxRx + else: + raise ValueError(length) + + for n_try in range(1 + num_retry): + value, comm, error = read_fn(self.port_handler, motor_id, address) + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to read @{address=} ({length=}) on {motor_id=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + elif self._is_error(error) and raise_on_error: + raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}") + + return value, comm, error + + def write( + self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0 + ) -> None: + """Write a value to a single motor's register. + + Contrary to :pymeth:`sync_write`, this expects a response status packet emitted by the motor, which + provides a guarantee that the value was written to the register successfully. In consequence, it is + slower than :pymeth:`sync_write` but it is more reliable. It should typically be used when configuring + motors. + + Args: + data_name (str): Register name. + motor (str): Motor name. + value (Value): Value to write. If *normalize* is `True` the value is first converted to raw + units. + normalize (bool, optional): Enable or disable normalisation. Defaults to `True`. + num_retry (int, optional): Retry attempts. Defaults to `0`. + """ + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + id_ = self.motors[motor].id + model = self.motors[motor].model + addr, length = get_address(self.model_ctrl_table, model, data_name) + + if normalize and data_name in self.normalized_data: + value = self._unnormalize({id_: value})[id_] + + value = self._encode_sign(data_name, {id_: value})[id_] + + err_msg = f"Failed to write '{data_name}' on {id_=} with '{value}' after {num_retry + 1} tries." + self._write(addr, length, id_, value, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + + def _write( + self, + addr: int, + length: int, + motor_id: int, + value: int, + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> tuple[int, int]: + data = self._serialize_data(value, length) + for n_try in range(1 + num_retry): + comm, error = self.packet_handler.writeTxRx(self.port_handler, motor_id, addr, length, data) + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to sync write @{addr=} ({length=}) on id={motor_id} with {value=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + elif self._is_error(error) and raise_on_error: + raise RuntimeError(f"{err_msg} {self.packet_handler.getRxPacketError(error)}") + + return comm, error + + def sync_read( + self, + data_name: str, + motors: str | list[str] | None = None, + *, + normalize: bool = True, + num_retry: int = 0, + ) -> dict[str, Value]: + """Read the same register from several motors at once. + + Args: + data_name (str): Register name. + motors (str | list[str] | None, optional): Motors to query. `None` (default) reads every motor. + normalize (bool, optional): Normalisation flag. Defaults to `True`. + num_retry (int, optional): Retry attempts. Defaults to `0`. + + Returns: + dict[str, Value]: Mapping *motor name → value*. + """ + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + self._assert_protocol_is_compatible("sync_read") + + names = self._get_motors_list(motors) + ids = [self.motors[motor].id for motor in names] + models = [self.motors[motor].model for motor in names] + + if self._has_different_ctrl_tables: + assert_same_address(self.model_ctrl_table, models, data_name) + + model = next(iter(models)) + addr, length = get_address(self.model_ctrl_table, model, data_name) + + err_msg = f"Failed to sync read '{data_name}' on {ids=} after {num_retry + 1} tries." + ids_values, _ = self._sync_read( + addr, length, ids, num_retry=num_retry, raise_on_error=True, err_msg=err_msg + ) + + ids_values = self._decode_sign(data_name, ids_values) + + if normalize and data_name in self.normalized_data: + ids_values = self._normalize(ids_values) + + return {self._id_to_name(id_): value for id_, value in ids_values.items()} + + def _sync_read( + self, + addr: int, + length: int, + motor_ids: list[int], + *, + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> tuple[dict[int, int], int]: + self._setup_sync_reader(motor_ids, addr, length) + for n_try in range(1 + num_retry): + comm = self.sync_reader.txRxPacket() + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to sync read @{addr=} ({length=}) on {motor_ids=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + + values = {id_: self.sync_reader.getData(id_, addr, length) for id_ in motor_ids} + return values, comm + + def _setup_sync_reader(self, motor_ids: list[int], addr: int, length: int) -> None: + self.sync_reader.clearParam() + self.sync_reader.start_address = addr + self.sync_reader.data_length = length + for id_ in motor_ids: + self.sync_reader.addParam(id_) + + # TODO(aliberts, pkooij): Implementing something like this could get even much faster read times if need be. + # Would have to handle the logic of checking if a packet has been sent previously though but doable. + # This could be at the cost of increase latency between the moment the data is produced by the motors and + # the moment it is used by a policy. + # def _async_read(self, motor_ids: list[int], address: int, length: int): + # if self.sync_reader.start_address != address or self.sync_reader.data_length != length or ...: + # self._setup_sync_reader(motor_ids, address, length) + # else: + # self.sync_reader.rxPacket() + # self.sync_reader.txPacket() + + # for id_ in motor_ids: + # value = self.sync_reader.getData(id_, address, length) + + def sync_write( + self, + data_name: str, + values: Value | dict[str, Value], + *, + normalize: bool = True, + num_retry: int = 0, + ) -> None: + """Write the same register on multiple motors. + + Contrary to :pymeth:`write`, this *does not* expects a response status packet emitted by the motor, which + can allow for lost packets. It is faster than :pymeth:`write` and should typically be used when + frequency matters and losing some packets is acceptable (e.g. teleoperation loops). + + Args: + data_name (str): Register name. + values (Value | dict[str, Value]): Either a single value (applied to every motor) or a mapping + *motor name → value*. + normalize (bool, optional): If `True` (default) convert values from the user range to raw units. + num_retry (int, optional): Retry attempts. Defaults to `0`. + """ + if not self.is_connected: + raise DeviceNotConnectedError( + f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`." + ) + + ids_values = self._get_ids_values_dict(values) + models = [self._id_to_model(id_) for id_ in ids_values] + if self._has_different_ctrl_tables: + assert_same_address(self.model_ctrl_table, models, data_name) + + model = next(iter(models)) + addr, length = get_address(self.model_ctrl_table, model, data_name) + + if normalize and data_name in self.normalized_data: + ids_values = self._unnormalize(ids_values) + + ids_values = self._encode_sign(data_name, ids_values) + + err_msg = f"Failed to sync write '{data_name}' with {ids_values=} after {num_retry + 1} tries." + self._sync_write(addr, length, ids_values, num_retry=num_retry, raise_on_error=True, err_msg=err_msg) + + def _sync_write( + self, + addr: int, + length: int, + ids_values: dict[int, int], + num_retry: int = 0, + raise_on_error: bool = True, + err_msg: str = "", + ) -> int: + self._setup_sync_writer(ids_values, addr, length) + for n_try in range(1 + num_retry): + comm = self.sync_writer.txPacket() + if self._is_comm_success(comm): + break + logger.debug( + f"Failed to sync write @{addr=} ({length=}) with {ids_values=} ({n_try=}): " + + self.packet_handler.getTxRxResult(comm) + ) + + if not self._is_comm_success(comm) and raise_on_error: + raise ConnectionError(f"{err_msg} {self.packet_handler.getTxRxResult(comm)}") + + return comm + + def _setup_sync_writer(self, ids_values: dict[int, int], addr: int, length: int) -> None: + self.sync_writer.clearParam() + self.sync_writer.start_address = addr + self.sync_writer.data_length = length + for id_, value in ids_values.items(): + data = self._serialize_data(value, length) + self.sync_writer.addParam(id_, data) diff --git a/lerobot/common/optim/optimizers.py b/lerobot/common/optim/optimizers.py index 0cf4124ce..903434f59 100644 --- a/lerobot/common/optim/optimizers.py +++ b/lerobot/common/optim/optimizers.py @@ -14,8 +14,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import abc -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from pathlib import Path +from typing import Any import draccus import torch @@ -44,7 +45,16 @@ class OptimizerConfig(draccus.ChoiceRegistry, abc.ABC): return "adam" @abc.abstractmethod - def build(self) -> torch.optim.Optimizer: + def build(self) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + """ + Build the optimizer. It can be a single optimizer or a dictionary of optimizers. + NOTE: Multiple optimizers are useful when you have different models to optimize. + For example, you can have one optimizer for the policy and another one for the value function + in reinforcement learning settings. + + Returns: + The optimizer or a dictionary of optimizers. + """ raise NotImplementedError @@ -94,7 +104,76 @@ class SGDConfig(OptimizerConfig): return torch.optim.SGD(params, **kwargs) -def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: +@OptimizerConfig.register_subclass("multi_adam") +@dataclass +class MultiAdamConfig(OptimizerConfig): + """Configuration for multiple Adam optimizers with different parameter groups. + + This creates a dictionary of Adam optimizers, each with its own hyperparameters. + + Args: + lr: Default learning rate (used if not specified for a group) + weight_decay: Default weight decay (used if not specified for a group) + optimizer_groups: Dictionary mapping parameter group names to their hyperparameters + grad_clip_norm: Gradient clipping norm + """ + + lr: float = 1e-3 + weight_decay: float = 0.0 + grad_clip_norm: float = 10.0 + optimizer_groups: dict[str, dict[str, Any]] = field(default_factory=dict) + + def build(self, params_dict: dict[str, list]) -> dict[str, torch.optim.Optimizer]: + """Build multiple Adam optimizers. + + Args: + params_dict: Dictionary mapping parameter group names to lists of parameters + The keys should match the keys in optimizer_groups + + Returns: + Dictionary mapping parameter group names to their optimizers + """ + optimizers = {} + + for name, params in params_dict.items(): + # Get group-specific hyperparameters or use defaults + group_config = self.optimizer_groups.get(name, {}) + + # Create optimizer with merged parameters (defaults + group-specific) + optimizer_kwargs = { + "lr": group_config.get("lr", self.lr), + "betas": group_config.get("betas", (0.9, 0.999)), + "eps": group_config.get("eps", 1e-5), + "weight_decay": group_config.get("weight_decay", self.weight_decay), + } + + optimizers[name] = torch.optim.Adam(params, **optimizer_kwargs) + + return optimizers + + +def save_optimizer_state( + optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path +) -> None: + """Save optimizer state to disk. + + Args: + optimizer: Either a single optimizer or a dictionary of optimizers. + save_dir: Directory to save the optimizer state. + """ + if isinstance(optimizer, dict): + # Handle dictionary of optimizers + for name, opt in optimizer.items(): + optimizer_dir = save_dir / name + optimizer_dir.mkdir(exist_ok=True, parents=True) + _save_single_optimizer_state(opt, optimizer_dir) + else: + # Handle single optimizer + _save_single_optimizer_state(optimizer, save_dir) + + +def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None: + """Save a single optimizer's state to disk.""" state = optimizer.state_dict() param_groups = state.pop("param_groups") flat_state = flatten_dict(state) @@ -102,11 +181,44 @@ def save_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> No write_json(param_groups, save_dir / OPTIMIZER_PARAM_GROUPS) -def load_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: +def load_optimizer_state( + optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path +) -> torch.optim.Optimizer | dict[str, torch.optim.Optimizer]: + """Load optimizer state from disk. + + Args: + optimizer: Either a single optimizer or a dictionary of optimizers. + save_dir: Directory to load the optimizer state from. + + Returns: + The updated optimizer(s) with loaded state. + """ + if isinstance(optimizer, dict): + # Handle dictionary of optimizers + loaded_optimizers = {} + for name, opt in optimizer.items(): + optimizer_dir = save_dir / name + if optimizer_dir.exists(): + loaded_optimizers[name] = _load_single_optimizer_state(opt, optimizer_dir) + else: + loaded_optimizers[name] = opt + return loaded_optimizers + else: + # Handle single optimizer + return _load_single_optimizer_state(optimizer, save_dir) + + +def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> torch.optim.Optimizer: + """Load a single optimizer's state from disk.""" current_state_dict = optimizer.state_dict() flat_state = load_file(save_dir / OPTIMIZER_STATE) state = unflatten_dict(flat_state) - loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + + # Handle case where 'state' key might not exist (for newly created optimizers) + if "state" in state: + loaded_state_dict = {"state": {int(k): v for k, v in state["state"].items()}} + else: + loaded_state_dict = {"state": {}} if "param_groups" in current_state_dict: param_groups = deserialize_json_into_object( diff --git a/lerobot/common/policies/__init__.py b/lerobot/common/policies/__init__.py index b73ba5f4e..9cb0f6234 100644 --- a/lerobot/common/policies/__init__.py +++ b/lerobot/common/policies/__init__.py @@ -15,5 +15,6 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index e73c65fe9..c8841f06b 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -81,7 +81,7 @@ class DiffusionConfig(PreTrainedConfig): n_groups: Number of groups used in the group norm of the Unet's convolutional blocks. diffusion_step_embed_dim: The Unet is conditioned on the diffusion timestep via a small non-linear network. This is the output dimension of that network, i.e., the embedding dimension. - use_film_scale_modulation: FiLM (https://arxiv.org/abs/1709.07871) is used for the Unet conditioning. + use_film_scale_modulation: FiLM (https://huggingface.co/papers/1709.07871) is used for the Unet conditioning. Bias modulation is used be default, while this parameter indicates whether to also use scale modulation. noise_scheduler_type: Name of the noise scheduler to use. Supported options: ["DDPM", "DDIM"]. diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 9ecadcb05..038136d07 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -33,7 +33,7 @@ from diffusers.schedulers.scheduling_ddim import DDIMScheduler from diffusers.schedulers.scheduling_ddpm import DDPMScheduler from torch import Tensor, nn -from lerobot.common.constants import OBS_ENV, OBS_ROBOT +from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGES, OBS_STATE from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -48,7 +48,7 @@ from lerobot.common.policies.utils import ( class DiffusionPolicy(PreTrainedPolicy): """ Diffusion Policy as per "Diffusion Policy: Visuomotor Policy Learning via Action Diffusion" - (paper: https://arxiv.org/abs/2303.04137, code: https://github.com/real-stanford/diffusion_policy). + (paper: https://huggingface.co/papers/2303.04137, code: https://github.com/real-stanford/diffusion_policy). """ config_class = DiffusionConfig @@ -99,6 +99,18 @@ class DiffusionPolicy(PreTrainedPolicy): if self.config.env_state_feature: self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + # stack n latest observations from the queue + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.diffusion.generate_actions(batch) + + # TODO(rcadene): make above methods return output dictionary? + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + + return actions + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -124,23 +136,15 @@ class DiffusionPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[key] for key in self.config.image_features], dim=-4 - ) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) - if len(self._queues["action"]) == 0: - # stack n latest observations from the queue - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.diffusion.generate_actions(batch) + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + self._queues[ACTION].extend(actions.transpose(0, 1)) - # TODO(rcadene): make above methods return output dictionary? - actions = self.unnormalize_outputs({"action": actions})["action"] - - self._queues["action"].extend(actions.transpose(0, 1)) - - action = self._queues["action"].popleft() + action = self._queues[ACTION].popleft() return action def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, None]: @@ -148,9 +152,7 @@ class DiffusionPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack( - [batch[key] for key in self.config.image_features], dim=-4 - ) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) # no output_dict so returning None @@ -238,8 +240,8 @@ class DiffusionModel(nn.Module): def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: """Encode image features and concatenate them all together along with the state vector.""" - batch_size, n_obs_steps = batch[OBS_ROBOT].shape[:2] - global_cond_feats = [batch[OBS_ROBOT]] + batch_size, n_obs_steps = batch[OBS_STATE].shape[:2] + global_cond_feats = [batch[OBS_STATE]] # Extract image features. if self.config.image_features: if self.config.use_separate_rgb_encoder_per_camera: @@ -269,7 +271,7 @@ class DiffusionModel(nn.Module): global_cond_feats.append(img_features) if self.config.env_state_feature: - global_cond_feats.append(batch[OBS_ENV]) + global_cond_feats.append(batch[OBS_ENV_STATE]) # Concatenate features then flatten to (B, global_cond_dim). return torch.cat(global_cond_feats, dim=-1).flatten(start_dim=1) @@ -370,7 +372,7 @@ class DiffusionModel(nn.Module): class SpatialSoftmax(nn.Module): """ Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. - (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + (https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation. At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" of activations of each channel, i.e., keypoints in the image space for the policy to focus on. @@ -728,7 +730,7 @@ class DiffusionConditionalResidualBlock1d(nn.Module): self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) - # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. + # FiLM modulation (https://huggingface.co/papers/1709.07871) outputs per-channel bias and (maybe) scale. cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index b3255ec10..9cc94b929 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -151,6 +151,7 @@ class Normalize(nn.Module): # TODO(rcadene): should we remove torch.no_grad? @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + # TODO: Remove this shallow copy batch = dict(batch) # shallow copy avoids mutating the input batch for key, ft in self.features.items(): if key not in batch: @@ -252,3 +253,168 @@ class Unnormalize(nn.Module): else: raise ValueError(norm_mode) return batch + + +# TODO (azouitine): We should replace all normalization on the policies with register_buffer normalization +# and remove the `Normalize` and `Unnormalize` classes. +def _initialize_stats_buffers( + module: nn.Module, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, +) -> None: + """Register statistics buffers (mean/std or min/max) on the given *module*. + + The logic matches the previous constructors of `NormalizeBuffer` and `UnnormalizeBuffer`, + but is factored out so it can be reused by both classes and stay in sync. + """ + for key, ft in features.items(): + norm_mode = norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + shape: tuple[int, ...] = tuple(ft.shape) + if ft.type is FeatureType.VISUAL: + # reduce spatial dimensions, keep channel dimension only + c, *_ = shape + shape = (c, 1, 1) + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = torch.full(shape, torch.inf, dtype=torch.float32) + std = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: + mean_data = stats[key]["mean"] + std_data = stats[key]["std"] + if isinstance(mean_data, torch.Tensor): + # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated + # tensors anywhere (for example, when we use the same stats for normalization and + # unnormalization). See the logic here + # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. + mean = mean_data.clone().to(dtype=torch.float32) + std = std_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_mean", mean) + module.register_buffer(f"{prefix}_std", std) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = torch.full(shape, torch.inf, dtype=torch.float32) + max_val = torch.full(shape, torch.inf, dtype=torch.float32) + + if stats and key in stats and "min" in stats[key] and "max" in stats[key]: + min_data = stats[key]["min"] + max_data = stats[key]["max"] + if isinstance(min_data, torch.Tensor): + min_val = min_data.clone().to(dtype=torch.float32) + max_val = max_data.clone().to(dtype=torch.float32) + else: + raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") + + module.register_buffer(f"{prefix}_min", min_val) + module.register_buffer(f"{prefix}_max", max_val) + continue + + raise ValueError(norm_mode) + + +class NormalizeBuffer(nn.Module): + """Same as `Normalize` but statistics are stored as registered buffers rather than parameters.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = (batch[key] - mean) / (std + 1e-8) + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] - min_val) / (max_val - min_val + 1e-8) + batch[key] = batch[key] * 2 - 1 + continue + + raise ValueError(norm_mode) + + return batch + + +class UnnormalizeBuffer(nn.Module): + """Inverse operation of `NormalizeBuffer`. Uses registered buffers for statistics.""" + + def __init__( + self, + features: dict[str, PolicyFeature], + norm_map: dict[str, NormalizationMode], + stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__() + self.features = features + self.norm_map = norm_map + + _initialize_stats_buffers(self, features, norm_map, stats) + + def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + # batch = dict(batch) + for key, ft in self.features.items(): + if key not in batch: + continue + + norm_mode = self.norm_map.get(ft.type, NormalizationMode.IDENTITY) + if norm_mode is NormalizationMode.IDENTITY: + continue + + prefix = key.replace(".", "_") + + if norm_mode is NormalizationMode.MEAN_STD: + mean = getattr(self, f"{prefix}_mean") + std = getattr(self, f"{prefix}_std") + assert not torch.isinf(mean).any(), _no_stats_error_str("mean") + assert not torch.isinf(std).any(), _no_stats_error_str("std") + batch[key] = batch[key] * std + mean + continue + + if norm_mode is NormalizationMode.MIN_MAX: + min_val = getattr(self, f"{prefix}_min") + max_val = getattr(self, f"{prefix}_max") + assert not torch.isinf(min_val).any(), _no_stats_error_str("min") + assert not torch.isinf(max_val).any(), _no_stats_error_str("max") + batch[key] = (batch[key] + 1) / 2 + batch[key] = batch[key] * (max_val - min_val) + min_val + continue + + raise ValueError(norm_mode) + + return batch diff --git a/lerobot/common/policies/pi0/modeling_pi0.py b/lerobot/common/policies/pi0/modeling_pi0.py index 7599fa635..97e66a272 100644 --- a/lerobot/common/policies/pi0/modeling_pi0.py +++ b/lerobot/common/policies/pi0/modeling_pi0.py @@ -57,7 +57,7 @@ import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn from transformers import AutoTokenizer -from lerobot.common.constants import ACTION, OBS_ROBOT +from lerobot.common.constants import ACTION, OBS_STATE from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0.paligemma_with_expert import ( @@ -260,6 +260,11 @@ class PI0Policy(PreTrainedPolicy): def get_optim_params(self) -> dict: return self.parameters() + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + raise NotImplementedError("Currently not implemented for PI0") + @torch.no_grad def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: """Select a single action given environment observations. @@ -271,7 +276,7 @@ class PI0Policy(PreTrainedPolicy): self.eval() if self.config.adapt_to_pi_aloha: - batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch = self.normalize_inputs(batch) @@ -303,7 +308,7 @@ class PI0Policy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]: """Do a full training forward pass to compute the loss""" if self.config.adapt_to_pi_aloha: - batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) batch = self.normalize_inputs(batch) @@ -357,7 +362,7 @@ class PI0Policy(PreTrainedPolicy): if self.config.resize_imgs_with_padding is not None: img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) - # Normalize from range [0,1] to [-1,1] as expacted by siglip + # Normalize from range [0,1] to [-1,1] as expected by siglip img = img * 2.0 - 1.0 bsize = img.shape[0] @@ -380,7 +385,7 @@ class PI0Policy(PreTrainedPolicy): def prepare_language(self, batch) -> tuple[Tensor, Tensor]: """Tokenize the text input""" - device = batch[OBS_ROBOT].device + device = batch[OBS_STATE].device tasks = batch["task"] # PaliGemma prompt has to end with a new line @@ -427,7 +432,7 @@ class PI0Policy(PreTrainedPolicy): def prepare_state(self, batch): """Pad state""" - state = pad_vector(batch[OBS_ROBOT], self.config.max_state_dim) + state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) return state def prepare_action(self, batch): diff --git a/lerobot/common/policies/pi0/paligemma_with_expert.py b/lerobot/common/policies/pi0/paligemma_with_expert.py index 76e2ce600..fb5077fb2 100644 --- a/lerobot/common/policies/pi0/paligemma_with_expert.py +++ b/lerobot/common/policies/pi0/paligemma_with_expert.py @@ -216,10 +216,14 @@ class PaliGemmaWithExpertModel(PreTrainedModel): param.data = param.data.to(dtype=torch.bfloat16) def embed_image(self, image: torch.Tensor): - return self.paligemma.get_image_features(image) + # Handle different transformers versions + if hasattr(self.paligemma, "get_image_features"): + return self.paligemma.get_image_features(image) + else: + return self.paligemma.model.get_image_features(image) def embed_language_tokens(self, tokens: torch.Tensor): - return self.paligemma.language_model.model.embed_tokens(tokens) + return self.paligemma.language_model.embed_tokens(tokens) # TODO: break down this huge forward into modules or functions def forward( @@ -231,7 +235,7 @@ class PaliGemmaWithExpertModel(PreTrainedModel): use_cache: Optional[bool] = None, fill_kv_cache: Optional[bool] = None, ): - models = [self.paligemma.language_model.model, self.gemma_expert.model] + models = [self.paligemma.language_model, self.gemma_expert.model] for hidden_states in inputs_embeds: # TODO this is very inefficient diff --git a/lerobot/common/policies/pretrained.py b/lerobot/common/policies/pretrained.py index da4ef1572..bc9276d0c 100644 --- a/lerobot/common/policies/pretrained.py +++ b/lerobot/common/policies/pretrained.py @@ -14,12 +14,14 @@ import abc import logging import os +from importlib.resources import files from pathlib import Path -from typing import Type, TypeVar +from tempfile import TemporaryDirectory +from typing import List, Type, TypeVar import packaging import safetensors -from huggingface_hub import hf_hub_download +from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.errors import HfHubHTTPError from safetensors.torch import load_model as load_model_as_safetensor @@ -28,20 +30,10 @@ from torch import Tensor, nn from lerobot.common.utils.hub import HubMixin from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.train import TrainPipelineConfig T = TypeVar("T", bound="PreTrainedPolicy") -DEFAULT_POLICY_CARD = """ ---- -# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 -# Doc / guide: https://huggingface.co/docs/hub/model-cards -{{ card_data }} ---- - -This policy has been pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot): -- Docs: {{ docs_url | default("[More Information Needed]", true) }} -""" - class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): """ @@ -150,16 +142,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) return model - # def generate_model_card(self, *args, **kwargs) -> ModelCard: - # card = ModelCard.from_template( - # card_data=self._hub_mixin_info.model_card_data, - # template_str=self._hub_mixin_info.model_card_template, - # repo_url=self._hub_mixin_info.repo_url, - # docs_url=self._hub_mixin_info.docs_url, - # **kwargs, - # ) - # return card - @abc.abstractmethod def get_optim_params(self) -> dict: """ @@ -189,6 +171,15 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): """ raise NotImplementedError + @abc.abstractmethod + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Returns the action chunk (for action chunking policies) for a given observation, potentially in batch mode. + + Child classes using action chunking should use this method within `select_action` to form the action chunk + cached for selection. + """ + raise NotImplementedError + @abc.abstractmethod def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Return one action to run in the environment (potentially in batch mode). @@ -197,3 +188,56 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): with caching. """ raise NotImplementedError + + def push_model_to_hub( + self, + cfg: TrainPipelineConfig, + ): + api = HfApi() + repo_id = api.create_repo( + repo_id=self.config.repo_id, private=self.config.private, exist_ok=True + ).repo_id + + # Push the files to the repo in a single commit + with TemporaryDirectory(ignore_cleanup_errors=True) as tmp: + saved_path = Path(tmp) / repo_id + + self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors + + card = self.generate_model_card( + cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags + ) + card.save(str(saved_path / "README.md")) + + cfg.save_pretrained(saved_path) # Calls _save_pretrained and stores train config + + commit_info = api.upload_folder( + repo_id=repo_id, + repo_type="model", + folder_path=saved_path, + commit_message="Upload policy weights, train config and readme", + allow_patterns=["*.safetensors", "*.json", "*.yaml", "*.md"], + ignore_patterns=["*.tmp", "*.log"], + ) + + logging.info(f"Model pushed to {commit_info.repo_url.url}") + + def generate_model_card( + self, dataset_repo_id: str, model_type: str, license: str | None, tags: List[str] | None + ) -> ModelCard: + base_model = "lerobot/smolvla_base" if model_type == "smolvla" else None # Set a base model + + card_data = ModelCardData( + license=license or "apache-2.0", + library_name="lerobot", + pipeline_tag="robotics", + tags=list(set(tags or []).union({"robotics", "lerobot", model_type})), + model_name=model_type, + datasets=dataset_repo_id, + base_model=base_model, + ) + + template_card = files("lerobot.templates").joinpath("lerobot_modelcard_template.md").read_text() + card = ModelCard.from_template(card_data, template_str=template_card) + card.validate() + return card diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py new file mode 100644 index 000000000..db58beb2f --- /dev/null +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -0,0 +1,245 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.constants import ACTION, OBS_IMAGE, OBS_STATE +from lerobot.common.optim.optimizers import MultiAdamConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +def is_image_feature(key: str) -> bool: + """Check if a feature key represents an image feature. + + Args: + key: The feature key to check + + Returns: + True if the key represents an image feature, False otherwise + """ + return key.startswith(OBS_IMAGE) + + +@dataclass +class ConcurrencyConfig: + """Configuration for the concurrency of the actor and learner. + Possible values are: + - "threads": Use threads for the actor and learner. + - "processes": Use processes for the actor and learner. + """ + + actor: str = "threads" + learner: str = "threads" + + +@dataclass +class ActorLearnerConfig: + learner_host: str = "127.0.0.1" + learner_port: int = 50051 + policy_parameters_push_frequency: int = 4 + queue_get_timeout: float = 2 + + +@dataclass +class CriticNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + final_activation: str | None = None + + +@dataclass +class ActorNetworkConfig: + hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + activate_final: bool = True + + +@dataclass +class PolicyConfig: + use_tanh_squash: bool = True + std_min: float = 1e-5 + std_max: float = 10.0 + init_final: float = 0.05 + + +@PreTrainedConfig.register_subclass("sac") +@dataclass +class SACConfig(PreTrainedConfig): + """Soft Actor-Critic (SAC) configuration. + + SAC is an off-policy actor-critic deep RL algorithm based on the maximum entropy + reinforcement learning framework. It learns a policy and a Q-function simultaneously + using experience collected from the environment. + + This configuration class contains all the parameters needed to define a SAC agent, + including network architectures, optimization settings, and algorithm-specific + hyperparameters. + """ + + # Mapping of feature types to normalization modes + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ENV": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + ) + + # Statistics for normalizing different types of inputs + dataset_stats: dict[str, dict[str, list[float]]] | None = field( + default_factory=lambda: { + OBS_IMAGE: { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + OBS_STATE: { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + ACTION: { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + ) + + # Architecture specifics + # Device to run the model on (e.g., "cuda", "cpu") + device: str = "cpu" + # Device to store the model on + storage_device: str = "cpu" + # Name of the vision encoder model (Set to "helper2424/resnet10" for hil serl resnet10) + vision_encoder_name: str | None = None + # Whether to freeze the vision encoder during training + freeze_vision_encoder: bool = True + # Hidden dimension size for the image encoder + image_encoder_hidden_dim: int = 32 + # Whether to use a shared encoder for actor and critic + shared_encoder: bool = True + # Number of discrete actions, eg for gripper actions + num_discrete_actions: int | None = None + # Dimension of the image embedding pooling + image_embedding_pooling_dim: int = 8 + + # Training parameter + # Number of steps for online training + online_steps: int = 1000000 + # Seed for the online environment + online_env_seed: int = 10000 + # Capacity of the online replay buffer + online_buffer_capacity: int = 100000 + # Capacity of the offline replay buffer + offline_buffer_capacity: int = 100000 + # Whether to use asynchronous prefetching for the buffers + async_prefetch: bool = False + # Number of steps before learning starts + online_step_before_learning: int = 100 + # Frequency of policy updates + policy_update_freq: int = 1 + + # SAC algorithm parameters + # Discount factor for the SAC algorithm + discount: float = 0.99 + # Initial temperature value + temperature_init: float = 1.0 + # Number of critics in the ensemble + num_critics: int = 2 + # Number of subsampled critics for training + num_subsample_critics: int | None = None + # Learning rate for the critic network + critic_lr: float = 3e-4 + # Learning rate for the actor network + actor_lr: float = 3e-4 + # Learning rate for the temperature parameter + temperature_lr: float = 3e-4 + # Weight for the critic target update + critic_target_update_weight: float = 0.005 + # Update-to-data ratio for the UTD algorithm (If you want enable utd_ratio, you need to set it to >1) + utd_ratio: int = 1 + # Hidden dimension size for the state encoder + state_encoder_hidden_dim: int = 256 + # Dimension of the latent space + latent_dim: int = 256 + # Target entropy for the SAC algorithm + target_entropy: float | None = None + # Whether to use backup entropy for the SAC algorithm + use_backup_entropy: bool = True + # Gradient clipping norm for the SAC algorithm + grad_clip_norm: float = 40.0 + + # Network configuration + # Configuration for the critic network architecture + critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for the actor network architecture + actor_network_kwargs: ActorNetworkConfig = field(default_factory=ActorNetworkConfig) + # Configuration for the policy parameters + policy_kwargs: PolicyConfig = field(default_factory=PolicyConfig) + # Configuration for the discrete critic network + discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig) + # Configuration for actor-learner architecture + actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig) + # Configuration for concurrency settings (you can use threads or processes for the actor and learner) + concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig) + + # Optimizations + use_torch_compile: bool = True + + def __post_init__(self): + super().__post_init__() + # Any validation specific to SAC configuration + + def get_optimizer_preset(self) -> MultiAdamConfig: + return MultiAdamConfig( + weight_decay=0.0, + optimizer_groups={ + "actor": {"lr": self.actor_lr}, + "critic": {"lr": self.critic_lr}, + "temperature": {"lr": self.temperature_lr}, + }, + ) + + def get_scheduler_preset(self) -> None: + return None + + def validate_features(self) -> None: + has_image = any(is_image_feature(key) for key in self.input_features) + has_state = OBS_STATE in self.input_features + + if not (has_state or has_image): + raise ValueError( + "You must provide either 'observation.state' or an image observation (key starting with 'observation.image') in the input features" + ) + + if "action" not in self.output_features: + raise ValueError("You must provide 'action' in the output features") + + @property + def image_features(self) -> list[str]: + return [key for key in self.input_features if is_image_feature(key)] + + @property + def observation_delta_indices(self) -> list: + return None + + @property + def action_delta_indices(self) -> list: + return None # SAC typically predicts one action at a time + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py new file mode 100644 index 000000000..1ca469351 --- /dev/null +++ b/lerobot/common/policies/sac/modeling_sac.py @@ -0,0 +1,1116 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import asdict +from typing import Callable, Literal + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor +from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution + +from lerobot.common.policies.normalize import NormalizeBuffer +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.configuration_sac import SACConfig, is_image_feature +from lerobot.common.policies.utils import get_device_from_parameters + +DISCRETE_DIMENSION_INDEX = -1 # Gripper is always the last dimension + + +class SACPolicy( + PreTrainedPolicy, +): + config_class = SACConfig + name = "sac" + + def __init__( + self, + config: SACConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + super().__init__(config) + config.validate_features() + self.config = config + + # Determine action dimension and initialize all components + continuous_action_dim = config.output_features["action"].shape[0] + self._init_normalization(dataset_stats) + self._init_encoders() + self._init_critics(continuous_action_dim) + self._init_actor(continuous_action_dim) + self._init_temperature() + + def get_optim_params(self) -> dict: + optim_params = { + "actor": [ + p + for n, p in self.actor.named_parameters() + if not n.startswith("encoder") or not self.shared_encoder + ], + "critic": self.critic_ensemble.parameters(), + "temperature": self.log_alpha, + } + if self.config.num_discrete_actions is not None: + optim_params["discrete_critic"] = self.discrete_critic.parameters() + return optim_params + + def reset(self): + """Reset the policy""" + pass + + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + raise NotImplementedError("SACPolicy does not support action chunking. It returns single actions!") + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select action for inference/evaluation""" + + observations_features = None + if self.shared_encoder and self.actor.encoder.has_images: + # Cache and normalize image features + observations_features = self.actor.encoder.get_cached_image_features(batch, normalize=True) + + actions, _, _ = self.actor(batch, observations_features) + + if self.config.num_discrete_actions is not None: + discrete_action_value = self.discrete_critic(batch, observations_features) + discrete_action = torch.argmax(discrete_action_value, dim=-1, keepdim=True) + actions = torch.cat([actions, discrete_action], dim=-1) + + return actions + + def critic_forward( + self, + observations: dict[str, Tensor], + actions: Tensor, + use_target: bool = False, + observation_features: Tensor | None = None, + ) -> Tensor: + """Forward pass through a critic network ensemble + + Args: + observations: Dictionary of observations + actions: Action tensor + use_target: If True, use target critics, otherwise use ensemble critics + + Returns: + Tensor of Q-values from all critics + """ + + critics = self.critic_target if use_target else self.critic_ensemble + q_values = critics(observations, actions, observation_features) + return q_values + + def discrete_critic_forward( + self, observations, use_target=False, observation_features=None + ) -> torch.Tensor: + """Forward pass through a discrete critic network + + Args: + observations: Dictionary of observations + use_target: If True, use target critics, otherwise use ensemble critics + observation_features: Optional pre-computed observation features to avoid recomputing encoder output + + Returns: + Tensor of Q-values from the discrete critic network + """ + discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic + q_values = discrete_critic(observations, observation_features) + return q_values + + def forward( + self, + batch: dict[str, Tensor | dict[str, Tensor]], + model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic", + ) -> dict[str, Tensor]: + """Compute the loss for the given model + + Args: + batch: Dictionary containing: + - action: Action tensor + - reward: Reward tensor + - state: Observations tensor dict + - next_state: Next observations tensor dict + - done: Done mask tensor + - observation_feature: Optional pre-computed observation features + - next_observation_feature: Optional pre-computed next observation features + model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature") + + Returns: + The computed loss tensor + """ + # Extract common components from batch + actions: Tensor = batch["action"] + observations: dict[str, Tensor] = batch["state"] + observation_features: Tensor = batch.get("observation_feature") + + if model == "critic": + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + + loss_critic = self.compute_loss_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + ) + + return {"loss_critic": loss_critic} + + if model == "discrete_critic" and self.config.num_discrete_actions is not None: + # Extract critic-specific components + rewards: Tensor = batch["reward"] + next_observations: dict[str, Tensor] = batch["next_state"] + done: Tensor = batch["done"] + next_observation_features: Tensor = batch.get("next_observation_feature") + complementary_info = batch.get("complementary_info") + loss_discrete_critic = self.compute_loss_discrete_critic( + observations=observations, + actions=actions, + rewards=rewards, + next_observations=next_observations, + done=done, + observation_features=observation_features, + next_observation_features=next_observation_features, + complementary_info=complementary_info, + ) + return {"loss_discrete_critic": loss_discrete_critic} + if model == "actor": + return { + "loss_actor": self.compute_loss_actor( + observations=observations, + observation_features=observation_features, + ) + } + + if model == "temperature": + return { + "loss_temperature": self.compute_loss_temperature( + observations=observations, + observation_features=observation_features, + ) + } + + raise ValueError(f"Unknown model type: {model}") + + def update_target_networks(self): + """Update target networks with exponential moving average""" + for target_param, param in zip( + self.critic_target.parameters(), + self.critic_ensemble.parameters(), + strict=True, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + if self.config.num_discrete_actions is not None: + for target_param, param in zip( + self.discrete_critic_target.parameters(), + self.discrete_critic.parameters(), + strict=True, + ): + target_param.data.copy_( + param.data * self.config.critic_target_update_weight + + target_param.data * (1.0 - self.config.critic_target_update_weight) + ) + + def update_temperature(self): + self.temperature = self.log_alpha.exp().item() + + def compute_loss_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features: Tensor | None = None, + next_observation_features: Tensor | None = None, + ) -> Tensor: + with torch.no_grad(): + next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features) + + # 2- compute q targets + q_targets = self.critic_forward( + observations=next_observations, + actions=next_action_preds, + use_target=True, + observation_features=next_observation_features, + ) + + # subsample critics to prevent overfitting if use high UTD (update to date) + # TODO: Get indices before forward pass to avoid unnecessary computation + if self.config.num_subsample_critics is not None: + indices = torch.randperm(self.config.num_critics) + indices = indices[: self.config.num_subsample_critics] + q_targets = q_targets[indices] + + # critics subsample size + min_q, _ = q_targets.min(dim=0) # Get values from min operation + if self.config.use_backup_entropy: + min_q = min_q - (self.temperature * next_log_probs) + + td_target = rewards + (1 - done) * self.config.discount * min_q + + # 3- compute predicted qs + if self.config.num_discrete_actions is not None: + # NOTE: We only want to keep the continuous action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX] + q_preds = self.critic_forward( + observations=observations, + actions=actions, + use_target=False, + observation_features=observation_features, + ) + + # 4- Calculate loss + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]) + # You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up + critics_loss = ( + F.mse_loss( + input=q_preds, + target=td_target_duplicate, + reduction="none", + ).mean(dim=1) + ).sum() + return critics_loss + + def compute_loss_discrete_critic( + self, + observations, + actions, + rewards, + next_observations, + done, + observation_features=None, + next_observation_features=None, + complementary_info=None, + ): + # NOTE: We only want to keep the discrete action part + # In the buffer we have the full action space (continuous + discrete) + # We need to split them before concatenating them in the critic forward + actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone() + actions_discrete = torch.round(actions_discrete) + actions_discrete = actions_discrete.long() + + discrete_penalties: Tensor | None = None + if complementary_info is not None: + discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty") + + with torch.no_grad(): + # For DQN, select actions using online network, evaluate with target network + next_discrete_qs = self.discrete_critic_forward( + next_observations, use_target=False, observation_features=next_observation_features + ) + best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True) + + # Get target Q-values from target network + target_next_discrete_qs = self.discrete_critic_forward( + observations=next_observations, + use_target=True, + observation_features=next_observation_features, + ) + + # Use gather to select Q-values for best actions + target_next_discrete_q = torch.gather( + target_next_discrete_qs, dim=1, index=best_next_discrete_action + ).squeeze(-1) + + # Compute target Q-value with Bellman equation + rewards_discrete = rewards + if discrete_penalties is not None: + rewards_discrete = rewards + discrete_penalties + target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q + + # Get predicted Q-values for current observations + predicted_discrete_qs = self.discrete_critic_forward( + observations=observations, use_target=False, observation_features=observation_features + ) + + # Use gather to select Q-values for taken actions + predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1) + + # Compute MSE loss between predicted and target Q-values + discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q) + return discrete_critic_loss + + def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor: + """Compute the temperature loss""" + # calculate temperature loss + with torch.no_grad(): + _, log_probs, _ = self.actor(observations, observation_features) + temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean() + return temperature_loss + + def compute_loss_actor( + self, + observations, + observation_features: Tensor | None = None, + ) -> Tensor: + actions_pi, log_probs, _ = self.actor(observations, observation_features) + + q_preds = self.critic_forward( + observations=observations, + actions=actions_pi, + use_target=False, + observation_features=observation_features, + ) + min_q_preds = q_preds.min(dim=0)[0] + + actor_loss = ((self.temperature * log_probs) - min_q_preds).mean() + return actor_loss + + def _init_normalization(self, dataset_stats): + """Initialize input/output normalization modules.""" + self.normalize_inputs = nn.Identity() + self.normalize_targets = nn.Identity() + if self.config.dataset_stats is not None: + params = _convert_normalization_params_to_tensor(self.config.dataset_stats) + self.normalize_inputs = NormalizeBuffer( + self.config.input_features, self.config.normalization_mapping, params + ) + stats = dataset_stats or params + self.normalize_targets = NormalizeBuffer( + self.config.output_features, self.config.normalization_mapping, stats + ) + + def _init_encoders(self): + """Initialize shared or separate encoders for actor and critic.""" + self.shared_encoder = self.config.shared_encoder + self.encoder_critic = SACObservationEncoder(self.config, self.normalize_inputs) + self.encoder_actor = ( + self.encoder_critic + if self.shared_encoder + else SACObservationEncoder(self.config, self.normalize_inputs) + ) + + def _init_critics(self, continuous_action_dim): + """Build critic ensemble, targets, and optional discrete critic.""" + heads = [ + CriticHead( + input_dim=self.encoder_critic.output_dim + continuous_action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_ensemble = CriticEnsemble( + encoder=self.encoder_critic, ensemble=heads, output_normalization=self.normalize_targets + ) + target_heads = [ + CriticHead( + input_dim=self.encoder_critic.output_dim + continuous_action_dim, + **asdict(self.config.critic_network_kwargs), + ) + for _ in range(self.config.num_critics) + ] + self.critic_target = CriticEnsemble( + encoder=self.encoder_critic, ensemble=target_heads, output_normalization=self.normalize_targets + ) + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + + if self.config.use_torch_compile: + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) + + if self.config.num_discrete_actions is not None: + self._init_discrete_critics() + + def _init_discrete_critics(self): + """Build discrete discrete critic ensemble and target networks.""" + self.discrete_critic = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + self.discrete_critic_target = DiscreteCritic( + encoder=self.encoder_critic, + input_dim=self.encoder_critic.output_dim, + output_dim=self.config.num_discrete_actions, + **asdict(self.config.discrete_critic_network_kwargs), + ) + + # TODO: (maractingi, azouitine) Compile the discrete critic + self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict()) + + def _init_actor(self, continuous_action_dim): + """Initialize policy actor network and default target entropy.""" + # NOTE: The actor select only the continuous action part + self.actor = Policy( + encoder=self.encoder_actor, + network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)), + action_dim=continuous_action_dim, + encoder_is_shared=self.shared_encoder, + **asdict(self.config.policy_kwargs), + ) + + self.target_entropy = self.config.target_entropy + if self.target_entropy is None: + dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0) + self.target_entropy = -np.prod(dim) / 2 + + def _init_temperature(self): + """Set up temperature parameter and initial log_alpha.""" + temp_init = self.config.temperature_init + self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)])) + self.temperature = self.log_alpha.exp().item() + + +class SACObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: SACConfig, input_normalizer: nn.Module) -> None: + super().__init__() + self.config = config + self.input_normalization = input_normalizer + self._init_image_layers() + self._init_state_layers() + self._compute_output_dim() + + def _init_image_layers(self) -> None: + self.image_keys = [k for k in self.config.input_features if is_image_feature(k)] + self.has_images = bool(self.image_keys) + if not self.has_images: + return + + if self.config.vision_encoder_name is not None: + self.image_encoder = PretrainedImageEncoder(self.config) + else: + self.image_encoder = DefaultImageEncoder(self.config) + + if self.config.freeze_vision_encoder: + freeze_image_encoder(self.image_encoder) + + dummy = torch.zeros(1, *self.config.input_features[self.image_keys[0]].shape) + with torch.no_grad(): + _, channels, height, width = self.image_encoder(dummy).shape + + self.spatial_embeddings = nn.ModuleDict() + self.post_encoders = nn.ModuleDict() + + for key in self.image_keys: + name = key.replace(".", "_") + self.spatial_embeddings[name] = SpatialLearnedEmbeddings( + height=height, + width=width, + channel=channels, + num_features=self.config.image_embedding_pooling_dim, + ) + self.post_encoders[name] = nn.Sequential( + nn.Dropout(0.1), + nn.Linear( + in_features=channels * self.config.image_embedding_pooling_dim, + out_features=self.config.latent_dim, + ), + nn.LayerNorm(normalized_shape=self.config.latent_dim), + nn.Tanh(), + ) + + def _init_state_layers(self) -> None: + self.has_env = "observation.environment_state" in self.config.input_features + self.has_state = "observation.state" in self.config.input_features + if self.has_env: + dim = self.config.input_features["observation.environment_state"].shape[0] + self.env_encoder = nn.Sequential( + nn.Linear(dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + if self.has_state: + dim = self.config.input_features["observation.state"].shape[0] + self.state_encoder = nn.Sequential( + nn.Linear(dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + + def _compute_output_dim(self) -> None: + out = 0 + if self.has_images: + out += len(self.image_keys) * self.config.latent_dim + if self.has_env: + out += self.config.latent_dim + if self.has_state: + out += self.config.latent_dim + self._out_dim = out + + def forward( + self, obs: dict[str, Tensor], cache: dict[str, Tensor] | None = None, detach: bool = False + ) -> Tensor: + obs = self.input_normalization(obs) + parts = [] + if self.has_images: + if cache is None: + cache = self.get_cached_image_features(obs, normalize=False) + parts.append(self._encode_images(cache, detach)) + if self.has_env: + parts.append(self.env_encoder(obs["observation.environment_state"])) + if self.has_state: + parts.append(self.state_encoder(obs["observation.state"])) + if parts: + return torch.cat(parts, dim=-1) + + raise ValueError( + "No parts to concatenate, you should have at least one image or environment state or state" + ) + + def get_cached_image_features(self, obs: dict[str, Tensor], normalize: bool = False) -> dict[str, Tensor]: + """Extract and optionally cache image features from observations. + + This function processes image observations through the vision encoder once and returns + the resulting features. + When the image encoder is shared between actor and critics AND frozen, these features can be safely cached and + reused across policy components (actor, critic, discrete_critic), avoiding redundant forward passes. + + Performance impact: + - The vision encoder forward pass is typically the main computational bottleneck during training and inference + - Caching these features can provide 2-4x speedup in training and inference + + Normalization behavior: + - When called from inside forward(): set normalize=False since inputs are already normalized + - When called from outside forward(): set normalize=True to ensure proper input normalization + + Usage patterns: + - Called in select_action() with normalize=True + - Called in learner.py's get_observation_features() to pre-compute features for all policy components + - Called internally by forward() with normalize=False + + Args: + obs: Dictionary of observation tensors containing image keys + normalize: Whether to normalize observations before encoding + Set to True when calling directly from outside the encoder's forward method + Set to False when calling from within forward() where inputs are already normalized + + Returns: + Dictionary mapping image keys to their corresponding encoded features + """ + if normalize: + obs = self.input_normalization(obs) + batched = torch.cat([obs[k] for k in self.image_keys], dim=0) + out = self.image_encoder(batched) + chunks = torch.chunk(out, len(self.image_keys), dim=0) + return dict(zip(self.image_keys, chunks, strict=False)) + + def _encode_images(self, cache: dict[str, Tensor], detach: bool) -> Tensor: + """Encode image features from cached observations. + + This function takes pre-encoded image features from the cache and applies spatial embeddings and post-encoders. + It also supports detaching the encoded features if specified. + + Args: + cache (dict[str, Tensor]): The cached image features. + detach (bool): Usually when the encoder is shared between actor and critics, + we want to detach the encoded features on the policy side to avoid backprop through the encoder. + More detail here `https://cdn.aaai.org/ojs/17276/17276-13-20770-1-2-20210518.pdf` + + Returns: + Tensor: The encoded image features. + """ + feats = [] + for k, feat in cache.items(): + safe_key = k.replace(".", "_") + x = self.spatial_embeddings[safe_key](feat) + x = self.post_encoders[safe_key](x) + if detach: + x = x.detach() + feats.append(x) + return torch.cat(feats, dim=-1) + + @property + def output_dim(self) -> int: + return self._out_dim + + +class MLP(nn.Module): + """Multi-layer perceptron builder. + + Dynamically constructs a sequence of layers based on `hidden_dims`: + 1) Linear (in_dim -> out_dim) + 2) Optional Dropout if `dropout_rate` > 0 and (not final layer or `activate_final`) + 3) LayerNorm on the output features + 4) Activation (standard for intermediate layers, `final_activation` for last layer if `activate_final`) + + Arguments: + input_dim (int): Size of input feature dimension. + hidden_dims (list[int]): Sizes for each hidden layer. + activations (Callable or str): Activation to apply between layers. + activate_final (bool): Whether to apply activation at the final layer. + dropout_rate (Optional[float]): Dropout probability applied before normalization and activation. + final_activation (Optional[Callable or str]): Activation for the final layer when `activate_final` is True. + + For each layer, `in_dim` is updated to the previous `out_dim`. All constructed modules are + stored in `self.net` as an `nn.Sequential` container. + """ + + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + layers: list[nn.Module] = [] + in_dim = input_dim + total = len(hidden_dims) + + for idx, out_dim in enumerate(hidden_dims): + # 1) linear transform + layers.append(nn.Linear(in_dim, out_dim)) + + is_last = idx == total - 1 + # 2-4) optionally add dropout, normalization, and activation + if not is_last or activate_final: + if dropout_rate and dropout_rate > 0: + layers.append(nn.Dropout(p=dropout_rate)) + layers.append(nn.LayerNorm(out_dim)) + act_cls = final_activation if is_last and final_activation else activations + act = act_cls if isinstance(act_cls, nn.Module) else getattr(nn, act_cls)() + layers.append(act) + + in_dim = out_dim + + self.net = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.net(x) + + +class CriticHead(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + init_final: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output_layer(self.net(x)) + + +class CriticEnsemble(nn.Module): + """ + CriticEnsemble wraps multiple CriticHead modules into an ensemble. + + Args: + encoder (SACObservationEncoder): encoder for observations. + ensemble (List[CriticHead]): list of critic heads. + output_normalization (nn.Module): normalization layer for actions. + init_final (float | None): optional initializer scale for final layers. + + Forward returns a tensor of shape (num_critics, batch_size) containing Q-values. + """ + + def __init__( + self, + encoder: SACObservationEncoder, + ensemble: list[CriticHead], + output_normalization: nn.Module, + init_final: float | None = None, + ): + super().__init__() + self.encoder = encoder + self.init_final = init_final + self.output_normalization = output_normalization + self.critics = nn.ModuleList(ensemble) + + def forward( + self, + observations: dict[str, torch.Tensor], + actions: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> torch.Tensor: + device = get_device_from_parameters(self) + # Move each tensor in observations to device + observations = {k: v.to(device) for k, v in observations.items()} + # NOTE: We normalize actions it helps for sample efficiency + actions: dict[str, torch.tensor] = {"action": actions} + # NOTE: Normalization layer took dict in input and outputs a dict that why + actions = self.output_normalization(actions)["action"] + actions = actions.to(device) + + obs_enc = self.encoder(observations, cache=observation_features) + + inputs = torch.cat([obs_enc, actions], dim=-1) + + # Loop through critics and collect outputs + q_values = [] + for critic in self.critics: + q_values.append(critic(inputs)) + + # Stack outputs to match expected shape [num_critics, batch_size] + q_values = torch.stack([q.squeeze(-1) for q in q_values], dim=0) + return q_values + + +class DiscreteCritic(nn.Module): + def __init__( + self, + encoder: nn.Module, + input_dim: int, + hidden_dims: list[int], + output_dim: int = 3, + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: float | None = None, + init_final: float | None = None, + final_activation: Callable[[torch.Tensor], torch.Tensor] | str | None = None, + ): + super().__init__() + self.encoder = encoder + self.output_dim = output_dim + + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + final_activation=final_activation, + ) + + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=self.output_dim) + if init_final is not None: + nn.init.uniform_(self.output_layer.weight, -init_final, init_final) + nn.init.uniform_(self.output_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.output_layer.weight) + + def forward( + self, observations: torch.Tensor, observation_features: torch.Tensor | None = None + ) -> torch.Tensor: + device = get_device_from_parameters(self) + observations = {k: v.to(device) for k, v in observations.items()} + obs_enc = self.encoder(observations, cache=observation_features) + return self.output_layer(self.net(obs_enc)) + + +class Policy(nn.Module): + def __init__( + self, + encoder: SACObservationEncoder, + network: nn.Module, + action_dim: int, + std_min: float = -5, + std_max: float = 2, + fixed_std: torch.Tensor | None = None, + init_final: float | None = None, + use_tanh_squash: bool = False, + encoder_is_shared: bool = False, + ): + super().__init__() + self.encoder: SACObservationEncoder = encoder + self.network = network + self.action_dim = action_dim + self.std_min = std_min + self.std_max = std_max + self.fixed_std = fixed_std + self.use_tanh_squash = use_tanh_squash + self.encoder_is_shared = encoder_is_shared + + # Find the last Linear layer's output dimension + for layer in reversed(network.net): + if isinstance(layer, nn.Linear): + out_features = layer.out_features + break + # Mean layer + self.mean_layer = nn.Linear(out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.mean_layer.weight, -init_final, init_final) + nn.init.uniform_(self.mean_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.mean_layer.weight) + + # Standard deviation layer or parameter + if fixed_std is None: + self.std_layer = nn.Linear(out_features, action_dim) + if init_final is not None: + nn.init.uniform_(self.std_layer.weight, -init_final, init_final) + nn.init.uniform_(self.std_layer.bias, -init_final, init_final) + else: + orthogonal_init()(self.std_layer.weight) + + def forward( + self, + observations: torch.Tensor, + observation_features: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + # We detach the encoder if it is shared to avoid backprop through it + # This is important to avoid the encoder to be updated through the policy + obs_enc = self.encoder(observations, cache=observation_features, detach=self.encoder_is_shared) + + # Get network outputs + outputs = self.network(obs_enc) + means = self.mean_layer(outputs) + + # Compute standard deviations + if self.fixed_std is None: + log_std = self.std_layer(outputs) + std = torch.exp(log_std) # Match JAX "exp" + std = torch.clamp(std, self.std_min, self.std_max) # Match JAX default clip + else: + std = self.fixed_std.expand_as(means) + + # Build transformed distribution + dist = TanhMultivariateNormalDiag(loc=means, scale_diag=std) + + # Sample actions (reparameterized) + actions = dist.rsample() + + # Compute log_probs + log_probs = dist.log_prob(actions) + + return actions, log_probs, means + + def get_features(self, observations: torch.Tensor) -> torch.Tensor: + """Get encoded features from observations""" + device = get_device_from_parameters(self) + observations = observations.to(device) + if self.encoder is not None: + with torch.inference_mode(): + return self.encoder(observations) + return observations + + +class DefaultImageEncoder(nn.Module): + def __init__(self, config: SACConfig): + super().__init__() + image_key = next(key for key in config.input_features if is_image_feature(key)) + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + in_channels=config.input_features[image_key].shape[0], + out_channels=config.image_encoder_hidden_dim, + kernel_size=7, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=5, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + ) + + def forward(self, x): + x = self.image_enc_layers(x) + return x + + +def freeze_image_encoder(image_encoder: nn.Module): + """Freeze all parameters in the encoder""" + for param in image_encoder.parameters(): + param.requires_grad = False + + +class PretrainedImageEncoder(nn.Module): + def __init__(self, config: SACConfig): + super().__init__() + + self.image_enc_layers, self.image_enc_out_shape = self._load_pretrained_vision_encoder(config) + + def _load_pretrained_vision_encoder(self, config: SACConfig): + """Set up CNN encoder""" + from transformers import AutoModel + + self.image_enc_layers = AutoModel.from_pretrained(config.vision_encoder_name, trust_remote_code=True) + + if hasattr(self.image_enc_layers.config, "hidden_sizes"): + self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[-1] # Last channel dimension + elif hasattr(self.image_enc_layers, "fc"): + self.image_enc_out_shape = self.image_enc_layers.fc.in_features + else: + raise ValueError("Unsupported vision encoder architecture, make sure you are using a CNN") + return self.image_enc_layers, self.image_enc_out_shape + + def forward(self, x): + enc_feat = self.image_enc_layers(x).last_hidden_state + return enc_feat + + +def orthogonal_init(): + return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) + + +class SpatialLearnedEmbeddings(nn.Module): + def __init__(self, height, width, channel, num_features=8): + """ + PyTorch implementation of learned spatial embeddings + + Args: + height: Spatial height of input features + width: Spatial width of input features + channel: Number of input channels + num_features: Number of output embedding dimensions + """ + super().__init__() + self.height = height + self.width = width + self.channel = channel + self.num_features = num_features + + self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features)) + + nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear") + + def forward(self, features): + """ + Forward pass for spatial embedding + + Args: + features: Input tensor of shape [B, C, H, W] where B is batch size, + C is number of channels, H is height, and W is width + Returns: + Output tensor of shape [B, C*F] where F is the number of features + """ + + features_expanded = features.unsqueeze(-1) # [B, C, H, W, 1] + kernel_expanded = self.kernel.unsqueeze(0) # [1, C, H, W, F] + + # Element-wise multiplication and spatial reduction + output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum over H,W dimensions + + # Reshape to combine channel and feature dimensions + output = output.view(output.size(0), -1) # [B, C*F] + + return output + + +class RescaleFromTanh(Transform): + def __init__(self, low: float = -1, high: float = 1): + super().__init__() + + self.low = low + + self.high = high + + def _call(self, x): + # Rescale from (-1, 1) to (low, high) + + return 0.5 * (x + 1.0) * (self.high - self.low) + self.low + + def _inverse(self, y): + # Rescale from (low, high) back to (-1, 1) + + return 2.0 * (y - self.low) / (self.high - self.low) - 1.0 + + def log_abs_det_jacobian(self, x, y): + # log|d(rescale)/dx| = sum(log(0.5 * (high - low))) + + scale = 0.5 * (self.high - self.low) + + return torch.sum(torch.log(scale), dim=-1) + + +class TanhMultivariateNormalDiag(TransformedDistribution): + def __init__(self, loc, scale_diag, low=None, high=None): + base_dist = MultivariateNormal(loc, torch.diag_embed(scale_diag)) + + transforms = [TanhTransform(cache_size=1)] + + if low is not None and high is not None: + low = torch.as_tensor(low) + + high = torch.as_tensor(high) + + transforms.insert(0, RescaleFromTanh(low, high)) + + super().__init__(base_dist, transforms) + + def mode(self): + # Mode is mean of base distribution, passed through transforms + + x = self.base_dist.mean + + for transform in self.transforms: + x = transform(x) + + return x + + def stddev(self): + std = self.base_dist.stddev + + x = std + + for transform in self.transforms: + x = transform(x) + + return x + + +def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict: + converted_params = {} + for outer_key, inner_dict in normalization_params.items(): + converted_params[outer_key] = {} + for key, value in inner_dict.items(): + converted_params[outer_key][key] = torch.tensor(value) + if "image" in outer_key: + converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1) + + return converted_params diff --git a/lerobot/common/policies/sac/reward_model/configuration_classifier.py b/lerobot/common/policies/sac/reward_model/configuration_classifier.py new file mode 100644 index 000000000..6e2a551d4 --- /dev/null +++ b/lerobot/common/policies/sac/reward_model/configuration_classifier.py @@ -0,0 +1,76 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig, OptimizerConfig +from lerobot.common.optim.schedulers import LRSchedulerConfig +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import NormalizationMode + + +@PreTrainedConfig.register_subclass(name="reward_classifier") +@dataclass +class RewardClassifierConfig(PreTrainedConfig): + """Configuration for the Reward Classifier model.""" + + name: str = "reward_classifier" + num_classes: int = 2 + hidden_dim: int = 256 + latent_dim: int = 256 + image_embedding_pooling_dim: int = 8 + dropout_rate: float = 0.1 + model_name: str = "helper2424/resnet10" + device: str = "cpu" + model_type: str = "cnn" # "transformer" or "cnn" + num_cameras: int = 2 + learning_rate: float = 1e-4 + weight_decay: float = 0.01 + grad_clip_norm: float = 1.0 + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.MEAN_STD, + } + ) + + @property + def observation_delta_indices(self) -> list | None: + return None + + @property + def action_delta_indices(self) -> list | None: + return None + + @property + def reward_delta_indices(self) -> list | None: + return None + + def get_optimizer_preset(self) -> OptimizerConfig: + return AdamWConfig( + lr=self.learning_rate, + weight_decay=self.weight_decay, + grad_clip_norm=self.grad_clip_norm, + ) + + def get_scheduler_preset(self) -> LRSchedulerConfig | None: + return None + + def validate_features(self) -> None: + """Validate feature configurations.""" + has_image = any(key.startswith("observation.image") for key in self.input_features) + if not has_image: + raise ValueError( + "You must provide an image observation (key starting with 'observation.image') in the input features" + ) diff --git a/lerobot/common/policies/sac/reward_model/modeling_classifier.py b/lerobot/common/policies/sac/reward_model/modeling_classifier.py new file mode 100644 index 000000000..7fec67f1a --- /dev/null +++ b/lerobot/common/policies/sac/reward_model/modeling_classifier.py @@ -0,0 +1,323 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +from torch import Tensor, nn + +from lerobot.common.constants import OBS_IMAGE, REWARD +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig + + +class ClassifierOutput: + """Wrapper for classifier outputs with additional metadata.""" + + def __init__( + self, + logits: Tensor, + probabilities: Tensor | None = None, + hidden_states: Tensor | None = None, + ): + self.logits = logits + self.probabilities = probabilities + self.hidden_states = hidden_states + + def __repr__(self): + return ( + f"ClassifierOutput(logits={self.logits}, " + f"probabilities={self.probabilities}, " + f"hidden_states={self.hidden_states})" + ) + + +class SpatialLearnedEmbeddings(nn.Module): + def __init__(self, height, width, channel, num_features=8): + """ + PyTorch implementation of learned spatial embeddings + + Args: + height: Spatial height of input features + width: Spatial width of input features + channel: Number of input channels + num_features: Number of output embedding dimensions + """ + super().__init__() + self.height = height + self.width = width + self.channel = channel + self.num_features = num_features + + self.kernel = nn.Parameter(torch.empty(channel, height, width, num_features)) + + nn.init.kaiming_normal_(self.kernel, mode="fan_in", nonlinearity="linear") + + def forward(self, features): + """ + Forward pass for spatial embedding + + Args: + features: Input tensor of shape [B, H, W, C] or [H, W, C] if no batch + Returns: + Output tensor of shape [B, C*F] or [C*F] if no batch + """ + + features = features.last_hidden_state + + original_shape = features.shape + if features.dim() == 3: + features = features.unsqueeze(0) # Add batch dim + + features_expanded = features.unsqueeze(-1) # [B, H, W, C, 1] + kernel_expanded = self.kernel.unsqueeze(0) # [1, H, W, C, F] + + # Element-wise multiplication and spatial reduction + output = (features_expanded * kernel_expanded).sum(dim=(2, 3)) # Sum H,W + + # Reshape to combine channel and feature dimensions + output = output.view(output.size(0), -1) # [B, C*F] + + # Remove batch dim + if len(original_shape) == 3: + output = output.squeeze(0) + + return output + + +class Classifier(PreTrainedPolicy): + """Image classifier built on top of a pre-trained encoder.""" + + name = "reward_classifier" + config_class = RewardClassifierConfig + + def __init__( + self, + config: RewardClassifierConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + from transformers import AutoModel + + super().__init__(config) + self.config = config + + # Initialize normalization (standardized with the policy framework) + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + # Set up encoder + encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) + # Extract vision model if we're given a multimodal model + if hasattr(encoder, "vision_model"): + logging.info("Multimodal model detected - using vision encoder only") + self.encoder = encoder.vision_model + self.vision_config = encoder.config.vision_config + else: + self.encoder = encoder + self.vision_config = getattr(encoder, "config", None) + + # Model type from config + self.is_cnn = self.config.model_type == "cnn" + + # For CNNs, initialize backbone + if self.is_cnn: + self._setup_cnn_backbone() + + self._freeze_encoder() + + # Extract image keys from input_features + self.image_keys = [ + key.replace(".", "_") for key in config.input_features if key.startswith(OBS_IMAGE) + ] + + if self.is_cnn: + self.encoders = nn.ModuleDict() + for image_key in self.image_keys: + encoder = self._create_single_encoder() + self.encoders[image_key] = encoder + + self._build_classifier_head() + + def _setup_cnn_backbone(self): + """Set up CNN encoder""" + if hasattr(self.encoder, "fc"): + self.feature_dim = self.encoder.fc.in_features + self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) + elif hasattr(self.encoder.config, "hidden_sizes"): + self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension + else: + raise ValueError("Unsupported CNN architecture") + + def _freeze_encoder(self) -> None: + """Freeze the encoder parameters.""" + for param in self.encoder.parameters(): + param.requires_grad = False + + def _create_single_encoder(self): + encoder = nn.Sequential( + self.encoder, + SpatialLearnedEmbeddings( + height=4, + width=4, + channel=self.feature_dim, + num_features=self.config.image_embedding_pooling_dim, + ), + nn.Dropout(self.config.dropout_rate), + nn.Linear(self.feature_dim * self.config.image_embedding_pooling_dim, self.config.latent_dim), + nn.LayerNorm(self.config.latent_dim), + nn.Tanh(), + ) + + return encoder + + def _build_classifier_head(self) -> None: + """Initialize the classifier head architecture.""" + # Get input dimension based on model type + if self.is_cnn: + input_dim = self.config.latent_dim + else: # Transformer models + if hasattr(self.encoder.config, "hidden_size"): + input_dim = self.encoder.config.hidden_size + else: + raise ValueError("Unsupported transformer architecture since hidden_size is not found") + + self.classifier_head = nn.Sequential( + nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), + nn.Dropout(self.config.dropout_rate), + nn.LayerNorm(self.config.hidden_dim), + nn.ReLU(), + nn.Linear( + self.config.hidden_dim, + 1 if self.config.num_classes == 2 else self.config.num_classes, + ), + ) + + def _get_encoder_output(self, x: torch.Tensor, image_key: str) -> torch.Tensor: + """Extract the appropriate output from the encoder.""" + with torch.no_grad(): + if self.is_cnn: + # The HF ResNet applies pooling internally + outputs = self.encoders[image_key](x) + return outputs + else: # Transformer models + outputs = self.encoder(x) + return outputs.last_hidden_state[:, 0, :] + + def extract_images_and_labels(self, batch: dict[str, Tensor]) -> tuple[list, Tensor]: + """Extract image tensors and label tensors from batch.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes + images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + labels = batch[REWARD] + + return images, labels + + def predict(self, xs: list) -> ClassifierOutput: + """Forward pass of the classifier for inference.""" + encoder_outputs = torch.hstack( + [self._get_encoder_output(x, img_key) for x, img_key in zip(xs, self.image_keys, strict=True)] + ) + logits = self.classifier_head(encoder_outputs) + + if self.config.num_classes == 2: + logits = logits.squeeze(-1) + probabilities = torch.sigmoid(logits) + else: + probabilities = torch.softmax(logits, dim=-1) + + return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Tensor]]: + """Standard forward pass for training compatible with train.py.""" + # Normalize inputs if needed + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images and labels + images, labels = self.extract_images_and_labels(batch) + + # Get predictions + outputs = self.predict(images) + + # Calculate loss + if self.config.num_classes == 2: + # Binary classification + loss = nn.functional.binary_cross_entropy_with_logits(outputs.logits, labels) + predictions = (torch.sigmoid(outputs.logits) > 0.5).float() + else: + # Multi-class classification + loss = nn.functional.cross_entropy(outputs.logits, labels.long()) + predictions = torch.argmax(outputs.logits, dim=1) + + # Calculate accuracy for logging + correct = (predictions == labels).sum().item() + total = labels.size(0) + accuracy = 100 * correct / total + + # Return loss and metrics for logging + output_dict = { + "accuracy": accuracy, + "correct": correct, + "total": total, + } + + return loss, output_dict + + def predict_reward(self, batch, threshold=0.5): + """Eval method. Returns predicted reward with the decision threshold as argument.""" + # Check for both OBS_IMAGE and OBS_IMAGES prefixes + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + + # Extract images from batch dict + images = [batch[key] for key in self.config.input_features if key.startswith(OBS_IMAGE)] + + if self.config.num_classes == 2: + probs = self.predict(images).probabilities + logging.debug(f"Predicted reward images: {probs}") + return (probs > threshold).float() + else: + return torch.argmax(self.predict(images).probabilities, dim=1) + + def get_optim_params(self): + """Return optimizer parameters for the policy.""" + return self.parameters() + + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not select actions. + """ + raise NotImplementedError("Reward classifiers do not select actions") + + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not produce action chunks. + """ + raise NotImplementedError("Reward classifiers do not predict action chunks") + + def reset(self): + """ + This method is required by PreTrainedPolicy but not used for reward classifiers. + The reward classifier is not an actor and does not select actions. + """ + pass diff --git a/lerobot/common/policies/smolvla/configuration_smolvla.py b/lerobot/common/policies/smolvla/configuration_smolvla.py new file mode 100644 index 000000000..5996cf2e7 --- /dev/null +++ b/lerobot/common/policies/smolvla/configuration_smolvla.py @@ -0,0 +1,154 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.optim.optimizers import AdamWConfig +from lerobot.common.optim.schedulers import ( + CosineDecayWithWarmupSchedulerConfig, +) +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +@PreTrainedConfig.register_subclass("smolvla") +@dataclass +class SmolVLAConfig(PreTrainedConfig): + # Input / output structure. + n_obs_steps: int = 1 + chunk_size: int = 50 + n_action_steps: int = 50 + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, + "ACTION": NormalizationMode.MEAN_STD, + } + ) + + # Shorter state and action vectors will be padded + max_state_dim: int = 32 + max_action_dim: int = 32 + + # Image preprocessing + resize_imgs_with_padding: tuple[int, int] = (512, 512) + + # Add empty images. Used by smolvla_aloha_sim which adds the empty + # left and right wrist cameras in addition to the top camera. + empty_cameras: int = 0 + + # Converts the joint and gripper values from the standard Aloha space to + # the space used by the pi internal runtime which was used to train the base model. + adapt_to_pi_aloha: bool = False + + # Converts joint dimensions to deltas with respect to the current state before passing to the model. + # Gripper dimensions will remain in absolute values. + use_delta_joint_actions_aloha: bool = False + + # Tokenizer + tokenizer_max_length: int = 48 + + # Decoding + num_steps: int = 10 + + # Attention utils + use_cache: bool = True + + # Finetuning settings + freeze_vision_encoder: bool = True + train_expert_only: bool = True + train_state_proj: bool = True + + # Training presets + optimizer_lr: float = 1e-4 + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 1e-10 + optimizer_grad_clip_norm: float = 10 + + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone. + load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights + + add_image_special_tokens: bool = False # Whether to use special image tokens around image features. + + attention_mode: str = "cross_attn" + + prefix_length: int = -1 + + pad_language_to: str = "longest" # "max_length" + + num_expert_layers: int = -1 # Less or equal to 0 is the default where the action expert has the same number of layers of VLM. Otherwise the expert have less layers. + num_vlm_layers: int = 16 # Number of layers used in the VLM (first num_vlm_layers layers) + self_attn_every_n_layers: int = 2 # Interleave SA layers each self_attn_every_n_layers + expert_width_multiplier: float = 0.75 # The action expert hidden size (wrt to the VLM) + + min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding + max_period: float = 4.0 + + def __post_init__(self): + super().__post_init__() + + """Input validation (not exhaustive).""" + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"The chunk size is the upper bound for the number of action steps per model invocation. Got " + f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." + ) + if self.use_delta_joint_actions_aloha: + raise NotImplementedError( + "`use_delta_joint_actions_aloha` is used by smolvla for aloha real models. It is not ported yet in LeRobot." + ) + + def validate_features(self) -> None: + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, 480, 640), + ) + self.input_features[key] = empty_camera + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def observation_delta_indices(self) -> list: + return [0] + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/lerobot/common/policies/smolvla/modeling_smolvla.py b/lerobot/common/policies/smolvla/modeling_smolvla.py new file mode 100644 index 000000000..361999844 --- /dev/null +++ b/lerobot/common/policies/smolvla/modeling_smolvla.py @@ -0,0 +1,940 @@ +#!/usr/bin/env python + +# Copyright 2025 HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +SmolVLA: + +[Paper](https://huggingface.co/papers/2506.01844) + +Designed by Hugging Face. + +Install smolvla extra dependencies: +```bash +pip install -e ".[smolvla]" +``` + +Example of finetuning the smolvla pretrained model (`smolvla_base`): +```bash +python lerobot/scripts/train.py \ +--policy.path=lerobot/smolvla_base \ +--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--batch_size=64 \ +--steps=200000 +``` + +Example of finetuning a smolVLA. SmolVLA is composed of a pretrained VLM, +and an action expert. +```bash +python lerobot/scripts/train.py \ +--policy.type=smolvla \ +--dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ +--batch_size=64 \ +--steps=200000 +``` + +Example of using the smolvla pretrained model outside LeRobot training framework: +```python +policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base") +``` + +""" + +import math +import os +import re +from collections import deque + +import safetensors +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from transformers import AutoProcessor + +from lerobot.common.constants import ACTION, OBS_STATE +from lerobot.common.policies.normalize import ( + Normalize, + Unnormalize, +) +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.common.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel +from lerobot.common.policies.utils import ( + populate_queues, +) +from lerobot.common.utils.utils import get_safe_dtype + +# Matches ".soNNN", optionally followed by "-something", up to the "_buffer_" marker +_VARIANT_RE = re.compile(r"\.so\d+(?:-[\w]+)?_buffer_") + + +def canonicalise(k: str) -> str: + """ + Remove dataset-variant markers like '.so100-blue_' or '.so100_' from a + normalisation-buffer key. + """ + return _VARIANT_RE.sub(".buffer_", k) + + +def standardise_state_dict( + checkpoint: dict[str, torch.Tensor], ref_keys: set[str], *, verbose: bool = True +) -> tuple[dict[str, torch.Tensor], list[str]]: + """ + • Re-keys `checkpoint ` so that every entry matches the *reference* key set. + • If several variant keys collapse to the same canonical name we keep the + first one and log the collision. + • Returns the new dict + a list of entries that could not be matched. + """ + out, collisions, unmatched = {}, {}, [] + + for k, v in checkpoint.items(): + canon = canonicalise(k) + if canon in ref_keys: + if canon in out: # duplicate after collapsing + collisions.setdefault(canon, []).append(k) + else: + out[canon] = v + else: + unmatched.append(k) + + if verbose: + for canon, variants in collisions.items(): + print(f"[standardise_state_dict] '{canon}' ← {variants}") + if unmatched: + print(f"[standardise_state_dict] kept {len(unmatched)} unmatched keys") + + out.update({k: checkpoint[k] for k in unmatched}) + return out, unmatched + + +def rename_checkpoint_keys(checkpoint: dict, rename_str: str): + """ + Renames keys in a checkpoint dictionary based on the given rename string. + + Args: + checkpoint (dict): The checkpoint dictionary. + rename_str (str): A string specifying key mappings in the format "old1//new1,old2//new2". + + Returns: + dict: The modified checkpoint with renamed keys. + """ + + rename_dict = dict(pair.split("//") for pair in rename_str.split(",")) + + new_checkpoint = {} + for k, v in checkpoint.items(): + for old_key, new_key in rename_dict.items(): + if old_key in k: + k = k.replace(old_key, new_key) + new_checkpoint[k] = v + return new_checkpoint + + +def load_smolvla( + model: torch.nn.Module, + filename: str | os.PathLike, + *, + device: str = "cpu", + checkpoint_keys_mapping: str = "", +) -> torch.nn.Module: + state_dict = safetensors.torch.load_file(filename, device=device) + + # Optional user-supplied renames (e.g. "model._orig_mod.//model.") + if checkpoint_keys_mapping and "//" in checkpoint_keys_mapping: + state_dict = rename_checkpoint_keys(state_dict, checkpoint_keys_mapping) + + state_dict, _ = standardise_state_dict(state_dict, set(model.state_dict().keys())) + + # HACK(aliberts): to not overwrite normalization parameters as they should come from the dataset + norm_keys = ("normalize_inputs", "normalize_targets", "unnormalize_outputs") + state_dict = {k: v for k, v in state_dict.items() if not k.startswith(norm_keys)} + + missing, unexpected = model.load_state_dict(state_dict, strict=False) + + if not all(key.startswith(norm_keys) for key in missing) or unexpected: + raise RuntimeError( + "SmolVLA %d missing / %d unexpected keys", + len(missing), + len(unexpected), + ) + + return model + + +def create_sinusoidal_pos_embedding( + time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu" +) -> Tensor: + """Computes sine-cosine positional embedding vectors for scalar positions.""" + if dimension % 2 != 0: + raise ValueError(f"dimension ({dimension}) must be divisible by 2") + + if time.ndim != 1: + raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.") + + dtype = get_safe_dtype(torch.float64, device.type) + fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device) + period = min_period * (max_period / min_period) ** fraction + + # Compute the outer product + scaling_factor = 1.0 / period * 2 * math.pi + sin_input = scaling_factor[None, :] * time[:, None] + pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1) + return pos_emb + + +def sample_beta(alpha, beta, bsize, device): + gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) + gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) + return gamma1 / (gamma1 + gamma2) + + +def make_att_2d_masks(pad_masks, att_masks): + """Copied from big_vision. + + Tokens can attend to valid inputs tokens which have a cumulative mask_ar + smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to + setup several types of attention, for example: + + [[1 1 1 1 1 1]]: pure causal attention. + + [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between + themselves and the last 3 tokens have a causal attention. The first + entry could also be a 1 without changing behaviour. + + [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a + block can attend all previous blocks and all tokens on the same block. + + Args: + input_mask: bool[B, N] true if its part of the input, false if padding. + mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on + it and 0 where it shares the same attention mask as the previous token. + """ + if att_masks.ndim != 2: + raise ValueError(att_masks.ndim) + if pad_masks.ndim != 2: + raise ValueError(pad_masks.ndim) + + cumsum = torch.cumsum(att_masks, dim=1) + att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None] + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + return att_2d_masks + + +def resize_with_pad(img, width, height, pad_value=-1): + # assume no-op when width height fits already + if img.ndim != 4: + raise ValueError(f"(b,c,h,w) expected, but {img.shape}") + + cur_height, cur_width = img.shape[2:] + + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + resized_img = F.interpolate( + img, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + pad_height = max(0, int(height - resized_height)) + pad_width = max(0, int(width - resized_width)) + + # pad on left and top of image + padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) + return padded_img + + +def pad_vector(vector, new_dim): + """Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] == new_dim: + return vector + shape = list(vector.shape) + current_dim = shape[-1] + shape[-1] = new_dim + new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device) + new_vector[..., :current_dim] = vector + return new_vector + + +def normalize(x, min_val, max_val): + return (x - min_val) / (max_val - min_val) + + +def unnormalize(x, min_val, max_val): + return x * (max_val - min_val) + min_val + + +def safe_arcsin(value): + # This ensures that the input stays within + # [−1,1] to avoid invalid values for arcsin + return torch.arcsin(torch.clamp(value, -1.0, 1.0)) + + +def aloha_gripper_to_angular(value): + # Aloha transforms the gripper positions into a linear space. The following code + # reverses this transformation to be consistent with smolvla which is pretrained in + # angular space. + # + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED + value = unnormalize(value, min_val=0.01844, max_val=0.05800) + + # This is the inverse of the angular to linear transformation inside the Interbotix code. + def linear_to_radian(linear_position, arm_length, horn_radius): + value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) + return safe_arcsin(value) + + # The constants are taken from the Interbotix code. + value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) + + # Normalize to [0, 1]. + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + return normalize(value, min_val=0.4, max_val=1.5) + + +def aloha_gripper_from_angular(value): + # Convert from the gripper position used by smolvla to the gripper position that is used by Aloha. + # Note that the units are still angular but the range is different. + + # The values 0.4 and 1.5 were measured on an actual Trossen robot. + value = unnormalize(value, min_val=0.4, max_val=1.5) + + # These values are coming from the Aloha code: + # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE + return normalize(value, min_val=-0.6213, max_val=1.4910) + + +def aloha_gripper_from_angular_inv(value): + # Directly inverts the gripper_from_angular function. + value = unnormalize(value, min_val=-0.6213, max_val=1.4910) + return normalize(value, min_val=0.4, max_val=1.5) + + +class SmolVLAPolicy(PreTrainedPolicy): + """Wrapper class around VLAFlowMatching model to train and run inference within LeRobot.""" + + config_class = SmolVLAConfig + name = "smolvla" + + def __init__( + self, + config: SmolVLAConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + + super().__init__(config) + config.validate_features() + self.config = config + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.language_tokenizer = AutoProcessor.from_pretrained(self.config.vlm_model_name).tokenizer + self.model = VLAFlowMatching(config) + self.reset() + + def reset(self): + """This should be called whenever the environment is reset.""" + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + # HACK(aliberts, danaaubakirova): we overwrite this classmethod here to fix smolVLA-specific issues + @classmethod + def _load_as_safetensor( + cls, + model: "SmolVLAPolicy", + model_file: str, + map_location: str, + strict: bool, + ): + safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) + return load_smolvla( + model, + model_file, + device=map_location, + checkpoint_keys_mapping="model._orig_mod.//model.", + ) + + def get_optim_params(self) -> dict: + return self.parameters() + + def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + for k in batch: + if k in self._queues: + batch[k] = torch.stack(list(self._queues[k]), dim=1) + + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + + actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise) + + # Unpad actions + original_action_dim = self.config.action_feature.shape[0] + actions = actions[:, :, :original_action_dim] + + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + + if self.config.adapt_to_pi_aloha: + actions = self._pi_aloha_encode_actions(actions) + + return actions + + def _prepare_batch(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + + batch = self.normalize_inputs(batch) + + return batch + + def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + self.eval() + + batch = self._prepare_batch(batch) + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + actions = self._get_action_chunk(batch, noise) + return actions + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() + batch = self._prepare_batch(batch) + self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION]) + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._queues[ACTION]) == 0: + actions = self._get_action_chunk(batch, noise) + + # `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps]) + + return self._queues[ACTION].popleft() + + def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]: + """Do a full training forward pass to compute the loss""" + if self.config.adapt_to_pi_aloha: + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) + batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) + batch = self.normalize_inputs(batch) + batch = self.normalize_targets(batch) + images, img_masks = self.prepare_images(batch) + state = self.prepare_state(batch) + lang_tokens, lang_masks = self.prepare_language(batch) + actions = self.prepare_action(batch) + actions_is_pad = batch.get("actions_id_pad") + loss_dict = {} + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) + loss_dict["losses_after_forward"] = losses.clone() + + if actions_is_pad is not None: + in_episode_bound = ~actions_is_pad + losses = losses * in_episode_bound.unsqueeze(-1) + loss_dict["losses_after_in_ep_bound"] = losses.clone() + + # Remove padding + losses = losses[:, :, : self.config.max_action_dim] + loss_dict["losses_after_rm_padding"] = losses.clone() + + # For backward pass + loss = losses.mean() + # For backward pass + loss_dict["loss"] = loss.item() + return loss, loss_dict + + def prepare_images(self, batch): + """Apply SmolVLA preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and + convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP. + """ + images = [] + img_masks = [] + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" + ) + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key][:, -1, :, :, :] if batch[key].ndim == 5 else batch[key] + if self.config.resize_imgs_with_padding is not None: + img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0) + + # Normalize from range [0,1] to [-1,1] as expacted by siglip + img = img * 2.0 - 1.0 + + bsize = img.shape[0] + device = img.device + if f"{key}_padding_mask" in batch: + mask = batch[f"{key}_padding_mask"].bool() + else: + mask = torch.ones(bsize, dtype=torch.bool, device=device) + images.append(img) + img_masks.append(mask) + + # Create image features not present in the batch + # as fully 0 padded images. + for num_empty_cameras in range(len(missing_img_keys)): + if num_empty_cameras >= self.config.empty_cameras: + break + img = torch.ones_like(img) * -1 + mask = torch.zeros_like(mask) + images.append(img) + img_masks.append(mask) + return images, img_masks + + def prepare_language(self, batch) -> tuple[Tensor, Tensor]: + """Tokenize the text input""" + device = batch[OBS_STATE].device + tasks = batch["task"] + if isinstance(tasks, str): + tasks = [tasks] + + if len(tasks) == 1: + tasks = [tasks[0] for _ in range(batch[OBS_STATE].shape[0])] + + tasks = [task if task.endswith("\n") else f"{task}\n" for task in tasks] + + tokenized_prompt = self.language_tokenizer.__call__( + tasks, + padding=self.config.pad_language_to, + padding_side="right", + max_length=self.config.tokenizer_max_length, + return_tensors="pt", + ) + lang_tokens = tokenized_prompt["input_ids"].to(device=device) + lang_masks = tokenized_prompt["attention_mask"].to(device=device, dtype=torch.bool) + + return lang_tokens, lang_masks + + def _pi_aloha_decode_state(self, state): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + state[:, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) + return state + + def _pi_aloha_encode_actions(self, actions): + # Flip the joints. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) + return actions + + def _pi_aloha_encode_actions_inv(self, actions): + # Flip the joints again. + for motor_idx in [1, 2, 8, 9]: + actions[:, :, motor_idx] *= -1 + # Reverse the gripper transformation that is being applied by the Aloha runtime. + for motor_idx in [6, 13]: + actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) + return actions + + def prepare_state(self, batch): + """Pad state""" + state = batch[OBS_STATE][:, -1, :] if batch[OBS_STATE].ndim > 2 else batch[OBS_STATE] + state = pad_vector(state, self.config.max_state_dim) + return state + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + +def pad_tensor(tensor, max_len, pad_value=0): + """ + Efficiently pads a tensor along sequence dimension to match max_len. + + Args: + tensor (torch.Tensor): Shape (B, L, ...) or (B, L). + max_len (int): Fixed sequence length. + pad_value (int/float): Value for padding. + + Returns: + torch.Tensor: Shape (B, max_len, ...) or (B, max_len). + """ + b, d = tensor.shape[:2] + + # Create a padded tensor of max_len and copy the existing values + padded_tensor = torch.full( + (b, max_len, *tensor.shape[2:]), pad_value, dtype=tensor.dtype, device=tensor.device + ) + padded_tensor[:, :d] = tensor # Efficient in-place copy + + return padded_tensor + + +class VLAFlowMatching(nn.Module): + """ + SmolVLA + + [Paper]() + + Designed by Hugging Face. + ┌──────────────────────────────┐ + │ actions │ + │ ▲ │ + │ ┌─────────┐ ┌─|────┐ │ + │ | │────► │ │ │ + │ | │ kv │ │ │ + │ | │────► │Action│ │ + │ | VLM │cache │Expert│ | + │ │ │────► | │ │ + │ │ │ │ │ │ + │ └▲──▲───▲─┘ └───▲──┘ | + │ │ | | │ | + │ | | | noise │ + │ │ │ state │ + │ │ language tokens │ + │ image(s) │ + └──────────────────────────────┘ + """ + + def __init__(self, config): + super().__init__() + self.config = config + + self.vlm_with_expert = SmolVLMWithExpertModel( + model_id=self.config.vlm_model_name, + freeze_vision_encoder=self.config.freeze_vision_encoder, + train_expert_only=self.config.train_expert_only, + load_vlm_weights=self.config.load_vlm_weights, + attention_mode=self.config.attention_mode, + num_expert_layers=self.config.num_expert_layers, + num_vlm_layers=self.config.num_vlm_layers, + self_attn_every_n_layers=self.config.self_attn_every_n_layers, + expert_width_multiplier=self.config.expert_width_multiplier, + ) + self.state_proj = nn.Linear( + self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size + ) + self.action_in_proj = nn.Linear(self.config.max_action_dim, self.vlm_with_expert.expert_hidden_size) + self.action_out_proj = nn.Linear(self.vlm_with_expert.expert_hidden_size, self.config.max_action_dim) + + self.action_time_mlp_in = nn.Linear( + self.vlm_with_expert.expert_hidden_size * 2, self.vlm_with_expert.expert_hidden_size + ) + self.action_time_mlp_out = nn.Linear( + self.vlm_with_expert.expert_hidden_size, self.vlm_with_expert.expert_hidden_size + ) + + self.set_requires_grad() + self.fake_image_token = self.vlm_with_expert.processor.tokenizer.fake_image_token_id + self.global_image_token = self.vlm_with_expert.processor.tokenizer.global_image_token_id + self.global_image_start_token = torch.tensor( + [self.fake_image_token, self.global_image_token], dtype=torch.long + ) + + self.add_image_special_tokens = self.config.add_image_special_tokens + self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long) + self.prefix_length = self.config.prefix_length + + def set_requires_grad(self): + for params in self.state_proj.parameters(): + params.requires_grad = self.config.train_state_proj + + def sample_noise(self, shape, device): + noise = torch.normal( + mean=0.0, + std=1.0, + size=shape, + dtype=torch.float32, + device=device, + ) + return noise + + def sample_time(self, bsize, device): + time_beta = sample_beta(1.5, 1.0, bsize, device) + time = time_beta * 0.999 + 0.001 + return time.to(dtype=torch.float32, device=device) + + def embed_prefix( + self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Embed images with SigLIP and language tokens with embedding layer to prepare + for SmolVLM transformer processing. + """ + embs = [] + pad_masks = [] + att_masks = [] + for _img_idx, ( + img, + img_mask, + ) in enumerate(zip(images, img_masks, strict=False)): + if self.add_image_special_tokens: + image_start_token = ( + self.vlm_with_expert.embed_language_tokens( + self.global_image_start_token.to(device=self.vlm_with_expert.vlm.device) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + image_start_mask = torch.ones_like( + image_start_token[:, :, 0], dtype=torch.bool, device=image_start_token.device + ) + att_masks += [0] * (image_start_mask.shape[-1]) + embs.append(image_start_token) + pad_masks.append(image_start_mask) + + img_emb = self.vlm_with_expert.embed_image(img) + img_emb = img_emb + + # Normalize image embeddings + img_emb_dim = img_emb.shape[-1] + img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) + + bsize, num_img_embs = img_emb.shape[:2] + img_mask = img_mask[:, None].expand(bsize, num_img_embs) + + embs.append(img_emb) + pad_masks.append(img_mask) + + att_masks += [0] * (num_img_embs) + if self.add_image_special_tokens: + image_end_token = ( + self.vlm_with_expert.embed_language_tokens( + self.image_end_token.to(device=self.vlm_with_expert.vlm.device) + ) + .unsqueeze(0) + .expand(img.shape[0], -1, -1) + ) + image_end_mask = torch.ones_like( + image_end_token[:, :, 0], dtype=torch.bool, device=image_end_token.device + ) + embs.append(image_end_token) + pad_masks.append(image_end_mask) + att_masks += [0] * (image_end_mask.shape[1]) + lang_emb = self.vlm_with_expert.embed_language_tokens(lang_tokens) + # Normalize language embeddings + lang_emb_dim = lang_emb.shape[-1] + lang_emb = lang_emb * math.sqrt(lang_emb_dim) + + embs.append(lang_emb) + pad_masks.append(lang_masks) + + num_lang_embs = lang_emb.shape[1] + att_masks += [0] * num_lang_embs + + state_emb = self.state_proj(state) + state_emb = state_emb[:, None, :] if state_emb.ndim == 2 else state_emb + embs.append(state_emb) + bsize = state_emb.shape[0] + device = state_emb.device + + states_seq_len = state_emb.shape[1] + state_mask = torch.ones(bsize, states_seq_len, dtype=torch.bool, device=device) + pad_masks.append(state_mask) + + # Set attention masks so that image and language inputs do not attend to state or actions + att_masks += [1] * (states_seq_len) + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) + att_masks = att_masks[None, :] + + seq_len = pad_masks.shape[1] + if seq_len < self.prefix_length: + embs = pad_tensor(embs, self.prefix_length, pad_value=0) + pad_masks = pad_tensor(pad_masks, self.prefix_length, pad_value=0) + att_masks = pad_tensor(att_masks, self.prefix_length, pad_value=0) + + att_masks = att_masks.expand(bsize, -1) + + return embs, pad_masks, att_masks + + def embed_suffix(self, noisy_actions, timestep): + """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" + embs = [] + pad_masks = [] + att_masks = [] + + # Fuse timestep + action information using an MLP + action_emb = self.action_in_proj(noisy_actions) + device = action_emb.device + bsize = action_emb.shape[0] + dtype = action_emb.dtype + # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] + time_emb = create_sinusoidal_pos_embedding( + timestep, + self.vlm_with_expert.expert_hidden_size, + self.config.min_period, + self.config.max_period, + device=device, + ) + time_emb = time_emb.type(dtype=dtype) + + time_emb = time_emb[:, None, :].expand_as(action_emb) + action_time_emb = torch.cat([action_emb, time_emb], dim=2) + + action_time_emb = self.action_time_mlp_in(action_time_emb) + action_time_emb = F.silu(action_time_emb) # swish == silu + action_time_emb = self.action_time_mlp_out(action_time_emb) + + # Add to input tokens + embs.append(action_time_emb) + + bsize, action_time_dim = action_time_emb.shape[:2] + action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) + pad_masks.append(action_time_mask) + + # Set attention masks so that image, language and state inputs do not attend to action tokens + att_masks += [1] * self.config.chunk_size + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) + att_masks = att_masks[None, :].expand(bsize, len(att_masks)) + return embs, pad_masks, att_masks + + def forward( + self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None + ) -> Tensor: + """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" + if noise is None: + noise = self.sample_noise(actions.shape, actions.device) + + if time is None: + time = self.sample_time(actions.shape[0], actions.device) + + time_expanded = time[:, None, None] + x_t = time_expanded * noise + (1 - time_expanded) * actions + u_t = noise - actions + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, time) + + pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) + att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) + + att_2d_masks = make_att_2d_masks(pad_masks, att_masks) + position_ids = torch.cumsum(pad_masks, dim=1) - 1 + (_, suffix_out), _ = self.vlm_with_expert.forward( + attention_mask=att_2d_masks, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, suffix_embs], + use_cache=False, + fill_kv_cache=False, + ) + suffix_out = suffix_out[:, -self.config.chunk_size :] + # Original openpi code, upcast attention output + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + losses = F.mse_loss(u_t, v_t, reduction="none") + return losses + + def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: + """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" + bsize = state.shape[0] + device = state.device + + if noise is None: + actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) + noise = self.sample_noise(actions_shape, device) + + prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( + images, img_masks, lang_tokens, lang_masks, state=state + ) + prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) + prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + # Compute image and language key value cache + _, past_key_values = self.vlm_with_expert.forward( + attention_mask=prefix_att_2d_masks, + position_ids=prefix_position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=self.config.use_cache, + fill_kv_cache=True, + ) + dt = -1.0 / self.config.num_steps + dt = torch.tensor(dt, dtype=torch.float32, device=device) + + x_t = noise + time = torch.tensor(1.0, dtype=torch.float32, device=device) + while time >= -dt / 2: + expanded_time = time.expand(bsize) + v_t = self.denoise_step( + prefix_pad_masks, + past_key_values, + x_t, + expanded_time, + ) + # Euler step + x_t += dt * v_t + time += dt + return x_t + + def denoise_step( + self, + prefix_pad_masks, + past_key_values, + x_t, + timestep, + ): + """Apply one denoising step of the noise `x_t` at a given timestep.""" + suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep) + + suffix_len = suffix_pad_masks.shape[1] + batch_size = prefix_pad_masks.shape[0] + prefix_len = prefix_pad_masks.shape[1] + prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) + + suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) + + full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) + prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] + position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 + + outputs_embeds, _ = self.vlm_with_expert.forward( + attention_mask=full_att_2d_masks, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=[None, suffix_embs], + use_cache=self.config.use_cache, + fill_kv_cache=False, + ) + suffix_out = outputs_embeds[1] + suffix_out = suffix_out[:, -self.config.chunk_size :] + suffix_out = suffix_out.to(dtype=torch.float32) + v_t = self.action_out_proj(suffix_out) + return v_t diff --git a/lerobot/common/policies/smolvla/smolvlm_with_expert.py b/lerobot/common/policies/smolvla/smolvlm_with_expert.py new file mode 100644 index 000000000..07eae8089 --- /dev/null +++ b/lerobot/common/policies/smolvla/smolvlm_with_expert.py @@ -0,0 +1,550 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +from typing import List, Optional + +import torch +from torch import nn +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForImageTextToText, + AutoProcessor, + SmolVLMForConditionalGeneration, +) + + +def apply_rope(x, positions, max_wavelength=10_000): + """ + Applies RoPE positions [B, L] to x [B, L, H, D]. + """ + d_half = x.shape[-1] // 2 + device = x.device + dtype = x.dtype + x = x.to(torch.float32) + + freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) + timescale = max_wavelength**freq_exponents + radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) + + radians = radians[..., None, :] + + sin = torch.sin(radians) # .to(dtype=dtype) + cos = torch.cos(radians) # .to(dtype=dtype) + + x1, x2 = x.split(d_half, dim=-1) + res = torch.empty_like(x) + res[..., :d_half] = x1 * cos - x2 * sin + res[..., d_half:] = x2 * cos + x1 * sin + + return res.to(dtype) + + +def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256): + hidden_dim = int(2 * hidden_dim / 3) + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + return hidden_dim + + +class SmolVLMWithExpertModel(nn.Module): + def __init__( + self, + model_id: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct", + load_vlm_weights: bool = True, + train_expert_only: bool = True, + freeze_vision_encoder: bool = False, + attention_mode: str = "self_attn", + num_expert_layers: int = -1, + num_vlm_layers: int = -1, + self_attn_every_n_layers: int = -1, + expert_width_multiplier: float = 0.5, + ): + super().__init__() + if load_vlm_weights: + print(f"Loading {model_id} weights ...") + self.vlm = AutoModelForImageTextToText.from_pretrained( + model_id, + device_map="auto", + torch_dtype="bfloat16", + low_cpu_mem_usage=True, + ) + config = self.vlm.config + else: + config = AutoConfig.from_pretrained(model_id) + self.vlm = SmolVLMForConditionalGeneration(config=config) + self.processor = AutoProcessor.from_pretrained(model_id) + if num_vlm_layers > 0: + print(f"Reducing the number of VLM layers to {num_vlm_layers} ...") + self.get_vlm_model().text_model.layers = self.get_vlm_model().text_model.layers[:num_vlm_layers] + self.num_vlm_layers = len(self.get_vlm_model().text_model.layers) + self.config = config + # Smaller lm expert + lm_expert_config = copy.deepcopy(config.text_config) + hidden_size = lm_expert_config.hidden_size + lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier) # hidden_size // 2 + lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier)) + lm_expert_config.num_hidden_layers = self.num_vlm_layers + if num_expert_layers > 0: + assert len(self.get_vlm_model().text_model.layers) % num_expert_layers == 0, ( + f"Number of layers in the VLM {len(self.get_vlm_model().text_model.layers)} are not multiple of num_expert_layers {num_expert_layers}" + ) + lm_expert_config.num_hidden_layers = num_expert_layers + self.lm_expert = AutoModel.from_config(lm_expert_config) + + self.num_expert_layers = len(self.lm_expert.layers) + self.self_attn_every_n_layers = self_attn_every_n_layers + if "cross" in attention_mode: + # Reshape qkv projections to have the same input dimension as the vlm + for layer_idx in range(len(self.lm_expert.layers)): + if self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0: + continue + self.lm_expert.layers[layer_idx].self_attn.k_proj = nn.Linear( + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, + ) + self.lm_expert.layers[layer_idx].self_attn.v_proj = nn.Linear( + config.text_config.num_key_value_heads * config.text_config.head_dim, + lm_expert_config.num_key_value_heads * lm_expert_config.head_dim, + bias=lm_expert_config.attention_bias, + ) + # Remove unused embed_tokens + self.lm_expert.embed_tokens = None + + self.num_attention_heads = self.config.text_config.num_attention_heads + self.num_key_value_heads = self.config.text_config.num_key_value_heads + + self.freeze_vision_encoder = freeze_vision_encoder + self.train_expert_only = train_expert_only + self.attention_mode = attention_mode + self.expert_hidden_size = lm_expert_config.hidden_size + self.set_requires_grad() + + def get_vlm_model(self): + return self.vlm.model + + def set_requires_grad(self): + if self.freeze_vision_encoder: + self.get_vlm_model().vision_model.eval() + for params in self.get_vlm_model().vision_model.parameters(): + params.requires_grad = False + if self.train_expert_only: + self.vlm.eval() + for params in self.vlm.parameters(): + params.requires_grad = False + else: + # To avoid unused params issue with distributed training + last_layers = [self.num_vlm_layers - 1] + if ( + self.num_vlm_layers != self.num_expert_layers + and self.num_vlm_layers % self.num_expert_layers == 0 + ): + last_layers.append(self.num_vlm_layers - 2) + frozen_layers = [ + "lm_head", + "text_model.model.norm.weight", + ] + for layer in last_layers: + frozen_layers.append(f"text_model.model.layers.{layer}.") + + for name, params in self.vlm.named_parameters(): + if any(k in name for k in frozen_layers): + params.requires_grad = False + # To avoid unused params issue with distributed training + for name, params in self.lm_expert.named_parameters(): + if "lm_head" in name: + params.requires_grad = False + + def train(self, mode: bool = True): + super().train(mode) + + if self.freeze_vision_encoder: + self.get_vlm_model().vision_model.eval() + + if self.train_expert_only: + self.vlm.eval() + + def embed_image(self, image: torch.Tensor): + patch_attention_mask = None + # Get sequence from the vision encoder + image_hidden_states = ( + self.get_vlm_model() + .vision_model( + pixel_values=image.to(dtype=self.get_vlm_model().vision_model.dtype), + patch_attention_mask=patch_attention_mask, + ) + .last_hidden_state + ) + # Modality projection & resampling + image_hidden_states = self.get_vlm_model().connector(image_hidden_states) + return image_hidden_states + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.get_vlm_model().text_model.get_input_embeddings()(tokens) + + def forward_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: + query_states = [] + key_states = [] + value_states = [] + for i, hidden_states in enumerate(inputs_embeds): + layer = model_layers[i][layer_idx] + if hidden_states is None or layer is None: + continue + hidden_states = layer.input_layernorm(hidden_states) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + query_states.append(query_state) + key_states.append(key_state) + value_states.append(value_state) + + # B,L,H,D with L sequence length, H number of heads, D head dim + # concatenate on the number of embeddings/tokens + query_states = torch.cat(query_states, dim=1) + key_states = torch.cat(key_states, dim=1) + value_states = torch.cat(value_states, dim=1) + seq_len = query_states.shape[1] + if seq_len < position_ids.shape[1]: + _position_ids = position_ids[:, :seq_len] + _attention_mask = attention_mask[:, :seq_len, :seq_len] + else: + _position_ids = position_ids + _attention_mask = attention_mask + + attention_mask_ = _attention_mask + position_ids_ = _position_ids + + query_states = apply_rope(query_states, position_ids_) + key_states = apply_rope(key_states, position_ids_) + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1) + value_states = torch.cat([past_key_values[layer_idx]["value_states"], value_states], dim=1) + + attention_interface = self.get_attention_interface() + + att_output = attention_interface( + attention_mask_, batch_size, head_dim, query_states, key_states, value_states + ) + return [att_output], past_key_values + + def forward_cross_attn_layer( + self, + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache: bool = True, + fill_kv_cache: bool = True, + past_key_values=None, + ) -> list[torch.Tensor]: + attention_interface = self.get_attention_interface() + + att_outputs = [] + assert len(inputs_embeds) == 2 or (use_cache and past_key_values is not None and not fill_kv_cache), ( + f"Both len(inputs_embeds) == {len(inputs_embeds)} and past_key_values is {past_key_values}" + ) + + if len(inputs_embeds) == 2 and not past_key_values: + # Prefix attention + seq_len = inputs_embeds[0].shape[1] + position_id, expert_position_id = position_ids[:, :seq_len], position_ids[:, seq_len:] + prefix_attention_mask = attention_mask[:, :seq_len, :seq_len] + + layer = model_layers[0][layer_idx] + + hidden_states = layer.input_layernorm(inputs_embeds[0]) + + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, layer.self_attn.head_dim) + + hidden_states = hidden_states.to(dtype=layer.self_attn.q_proj.weight.dtype) + query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape) + key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape) + value_states = layer.self_attn.v_proj(hidden_states).view(hidden_shape) + + # B,L,H,D with L sequence length, H number of heads, D head dim + query_states = apply_rope(query_state, position_id) + key_states = apply_rope(key_state, position_id) + + att_output = attention_interface( + prefix_attention_mask, batch_size, head_dim, query_states, key_states, value_states + ) + att_outputs.append(att_output) + else: + expert_position_id = position_ids + + if use_cache and past_key_values is None: + past_key_values = {} + + if use_cache: + if fill_kv_cache: + past_key_values[layer_idx] = { + "key_states": key_states, + "value_states": value_states, + } + else: + # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before. + # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach + # the max len, then we (for instance) double the cache size. This implementation already exists + # in `transformers`. (molbap) + key_states = past_key_values[layer_idx]["key_states"] + value_states = past_key_values[layer_idx]["value_states"] + + # Expert + expert_layer = model_layers[1][layer_idx] + if expert_layer is not None: + expert_hidden_states = expert_layer.input_layernorm(inputs_embeds[1]) + + expert_input_shape = expert_hidden_states.shape[:-1] + expert_hidden_shape = (*expert_input_shape, -1, expert_layer.self_attn.head_dim) + + expert_hidden_states = expert_hidden_states.to(dtype=expert_layer.self_attn.q_proj.weight.dtype) + expert_query_state = expert_layer.self_attn.q_proj(expert_hidden_states).view(expert_hidden_shape) + + _key_states = key_states.to(dtype=expert_layer.self_attn.k_proj.weight.dtype).view( + *key_states.shape[:2], -1 + ) + expert_key_states = expert_layer.self_attn.k_proj(_key_states).view( + *_key_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) # k_proj should have same dim as kv + + _value_states = value_states.to(dtype=expert_layer.self_attn.v_proj.weight.dtype).view( + *value_states.shape[:2], -1 + ) + expert_value_states = expert_layer.self_attn.v_proj(_value_states).view( + *_value_states.shape[:-1], -1, expert_layer.self_attn.head_dim + ) + + expert_position_id = ( + expert_position_id - torch.min(expert_position_id, dim=1, keepdim=True).values + ) # start from 0 + expert_attention_mask = attention_mask[ + :, -inputs_embeds[1].shape[1] :, : expert_key_states.shape[1] : + ] # take into account kv + + expert_query_states = apply_rope(expert_query_state, expert_position_id) + + att_output = attention_interface( + expert_attention_mask, + batch_size, + head_dim, + expert_query_states, + expert_key_states, + expert_value_states, + ) + att_outputs.append(att_output) + else: + att_outputs.append(None) + + # att_output = att_output.to(dtype=models[i].dtype) + return att_outputs, past_key_values + + def get_model_layers(self, models: list) -> list: + vlm_layers = [] + expert_layers = [] + multiple_of = self.num_vlm_layers // self.num_expert_layers + for i in range(self.num_vlm_layers): + if multiple_of > 0 and i > 0 and i % multiple_of != 0: + expert_layer = None + else: + expert_layer_index = i // multiple_of if multiple_of > 0 else i + expert_layer = models[1].layers[expert_layer_index] + vlm_layers.append(models[0].layers[i]) + expert_layers.append(expert_layer) + return [vlm_layers, expert_layers] + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: List[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + fill_kv_cache: Optional[bool] = None, + ): + models = [self.get_vlm_model().text_model, self.lm_expert] + model_layers = self.get_model_layers(models) + for hidden_states in inputs_embeds: + # TODO this is very inefficient + # dtype is always the same, batch size too (if > 1 len) + # device could be trickier in multi gpu edge cases but that's it + if hidden_states is None: + continue + batch_size = hidden_states.shape[0] + + # RMSNorm + num_layers = self.num_vlm_layers + head_dim = self.vlm.config.text_config.head_dim + for layer_idx in range(num_layers): + if ( + fill_kv_cache + or "cross" not in self.attention_mode + or (self.self_attn_every_n_layers > 0 and layer_idx % self.self_attn_every_n_layers == 0) + ): + att_outputs, past_key_values = self.forward_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) + else: + att_outputs, past_key_values = self.forward_cross_attn_layer( + model_layers, + inputs_embeds, + layer_idx, + position_ids, + attention_mask, + batch_size, + head_dim, + use_cache=use_cache, + fill_kv_cache=fill_kv_cache, + past_key_values=past_key_values, + ) + outputs_embeds = [] + start = 0 + for i, hidden_states in enumerate(inputs_embeds): + layer = model_layers[i][layer_idx] + att_output = ( + att_outputs[i] if i < len(att_outputs) else att_outputs[0] + ) # in case of self_attn + if hidden_states is not None: + if layer is None: + outputs_embeds.append(hidden_states) + continue + end = start + hidden_states.shape[1] + + if att_output.dtype != layer.self_attn.o_proj.weight.dtype: + att_output = att_output.to(layer.self_attn.o_proj.weight.dtype) + att_out = att_output[:, start:end] + out_emb = layer.self_attn.o_proj(att_out) + + out_emb += hidden_states + after_first_residual = out_emb.clone() + + out_emb = layer.post_attention_layernorm(out_emb) + out_emb = layer.mlp(out_emb) + + out_emb += after_first_residual + + outputs_embeds.append(out_emb) + + start = end if len(att_outputs) == 1 else 0 + else: + outputs_embeds.append(None) + + inputs_embeds = outputs_embeds + + # final norm + outputs_embeds = [] + for i, hidden_states in enumerate(inputs_embeds): + if hidden_states is not None: + out_emb = models[i].norm(hidden_states) + outputs_embeds.append(out_emb) + else: + outputs_embeds.append(None) + return outputs_embeds, past_key_values + + def get_attention_interface(self): + attention_interface = self.eager_attention_forward + return attention_interface + + def eager_attention_forward( + self, attention_mask, batch_size, head_dim, query_states, key_states, value_states + ): + num_att_heads = self.num_attention_heads + num_key_value_heads = self.num_key_value_heads + num_key_value_groups = num_att_heads // num_key_value_heads + + sequence_length = key_states.shape[1] + + key_states = key_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + key_states = key_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + value_states = value_states[:, :, :, None, :].expand( + batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim + ) + value_states = value_states.reshape( + batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim + ) + + # Attention here is upcasted to float32 to match the original eager implementation. + query_states = query_states.to(dtype=torch.float32) + key_states = key_states.to(dtype=torch.float32) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + + att_weights = torch.matmul(query_states, key_states.transpose(2, 3)) + att_weights *= head_dim**-0.5 + + att_weights = att_weights.to(dtype=torch.float32) + big_neg = torch.finfo(att_weights.dtype).min # -2.3819763e38 # See gemma/modules.py + masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg) + probs = nn.functional.softmax(masked_att_weights, dim=-1) + probs = probs.to(dtype=value_states.dtype) + + att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3)) + + att_output = att_output.permute(0, 2, 1, 3) + # we use -1 because sequence length can change + att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim) + + return att_output diff --git a/lerobot/common/policies/utils.py b/lerobot/common/policies/utils.py index c06e620ba..5659e8727 100644 --- a/lerobot/common/policies/utils.py +++ b/lerobot/common/policies/utils.py @@ -14,15 +14,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import deque + import torch from torch import nn -def populate_queues(queues, batch): +def populate_queues( + queues: dict[str, deque], batch: dict[str, torch.Tensor], exclude_keys: list[str] | None = None +): + if exclude_keys is None: + exclude_keys = [] for key in batch: # Ignore keys not in the queues already (leaving the responsibility to the caller to make sure the # queues have the keys they want). - if key not in queues: + if key not in queues or key in exclude_keys: continue if len(queues[key]) != queues[key].maxlen: # initialize by copying the first observation several times until the queue is full diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 97a08e2f4..a76bea2ab 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -27,6 +27,7 @@ import torch.nn.functional as F # noqa: N812 import torchvision from torch import Tensor, nn +from lerobot.common.constants import ACTION, OBS_IMAGES, OBS_STATE from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.utils import get_device_from_parameters, get_output_shape, populate_queues @@ -118,11 +119,18 @@ class VQBeTPolicy(PreTrainedPolicy): queues are populated during rollout of the policy, they contain the n latest observations and actions """ self._queues = { - "observation.images": deque(maxlen=self.config.n_obs_steps), - "observation.state": deque(maxlen=self.config.n_obs_steps), - "action": deque(maxlen=self.config.action_chunk_size), + OBS_IMAGES: deque(maxlen=self.config.n_obs_steps), + OBS_STATE: deque(maxlen=self.config.n_obs_steps), + ACTION: deque(maxlen=self.config.action_chunk_size), } + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -144,31 +152,27 @@ class VQBeTPolicy(PreTrainedPolicy): stacklevel=1, ) - if len(self._queues["action"]) == 0: - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] - - # the dimension of returned action is (batch_size, action_chunk_size, action_dim) - actions = self.unnormalize_outputs({"action": actions})["action"] + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) # since the data in the action queue's dimension is (action_chunk_size, batch_size, action_dim), we transpose the action and fill the queue - self._queues["action"].extend(actions.transpose(0, 1)) + self._queues[ACTION].extend(actions.transpose(0, 1)) - action = self._queues["action"].popleft() + action = self._queues[ACTION].popleft() return action def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) batch = self.normalize_targets(batch) - # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) + # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://huggingface.co/papers/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): # loss: total loss of training RVQ # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree). loss, n_different_codes, n_different_combinations, recon_l1_error = ( - self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) + self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch[ACTION]) ) return loss, { "n_different_codes": n_different_codes, @@ -185,7 +189,7 @@ class VQBeTPolicy(PreTrainedPolicy): class SpatialSoftmax(nn.Module): """ Spatial Soft Argmax operation described in "Deep Spatial Autoencoders for Visuomotor Learning" by Finn et al. - (https://arxiv.org/pdf/1509.06113). A minimal port of the robomimic implementation. + (https://huggingface.co/papers/1509.06113). A minimal port of the robomimic implementation. At a high level, this takes 2D feature maps (from a convnet/ViT) and returns the "center of mass" of activations of each channel, i.e., keypoints in the image space for the policy to focus on. @@ -387,7 +391,7 @@ class VQBeTModel(nn.Module): # only extract the output tokens at the position of action query: # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, - # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://arxiv.org/pdf/2206.11251). + # mapping sequential observation to sequential action (please refer to section 2.2 in BeT paper https://huggingface.co/papers/2206.11251). # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional). if len_additional_action_token > 0: features = torch.cat( @@ -404,7 +408,7 @@ class VQBeTModel(nn.Module): ) # else, it calculate overall loss (bin prediction loss, and offset loss) else: - output = batch["action"][:, self.select_target_actions_indices] + output = batch[ACTION][:, self.select_target_actions_indices] loss = self.action_head.loss_fn(action_head_output, output, reduction="mean") return action_head_output, loss @@ -824,8 +828,8 @@ class VqVae(nn.Module): return einops.rearrange(output, "N (T A) -> N T A", A=self.config.action_feature.shape[0]) def get_code(self, state): - # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181) - # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://arxiv.org/pdf/2403.03181) + # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. (please refer to section 3.3 in the paper https://huggingface.co/papers/2403.03181) + # this function outputs the `GT code` of given action using frozen encoder and quantization layers. (please refer to Figure 2. in the paper https://huggingface.co/papers/2403.03181) state = einops.rearrange(state, "N T A -> N (T A)") with torch.no_grad(): state_rep = self.encoder(state) @@ -838,7 +842,7 @@ class VqVae(nn.Module): return state_vq, vq_code def vqvae_forward(self, state): - # This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://arxiv.org/pdf/2403.03181). + # This function passes the given data through Residual VQ with Encoder and Decoder. Please refer to section 3.2 in the paper https://huggingface.co/papers/2403.03181). state = einops.rearrange(state, "N T A -> N (T A)") # We start with passing action (or action chunk) at:t+n through the encoder ϕ. state_rep = self.encoder(state) diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py index 139d119ed..09a86c07b 100644 --- a/lerobot/common/policies/vqbet/vqbet_utils.py +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -336,7 +336,7 @@ class ResidualVQ(nn.Module): """ Residual VQ is composed of multiple VectorQuantize layers. - Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf + Follows Algorithm 1. in https://huggingface.co/papers/2107.03312 "Residual Vector Quantizer (a.k.a. multi-stage vector quantizer [36]) cascades Nq layers of VQ as follows. The unquantized input vector is passed through a first VQ and quantization residuals are computed. The residuals are then iteratively quantized by a sequence of additional Nq -1 vector quantizers, as described in Algorithm 1." @@ -1006,7 +1006,7 @@ def gumbel_sample( if not straight_through or temperature <= 0.0 or not training: return ind, one_hot - # use reinmax for better second-order accuracy - https://arxiv.org/abs/2304.08612 + # use reinmax for better second-order accuracy - https://huggingface.co/papers/2304.08612 # algorithm 2 if reinmax: @@ -1156,7 +1156,7 @@ def batched_embedding(indices, embeds): def orthogonal_loss_fn(t): - # eq (2) from https://arxiv.org/abs/2112.00384 + # eq (2) from https://huggingface.co/papers/2112.00384 h, n = t.shape[:2] normed_codes = F.normalize(t, p=2, dim=-1) cosine_sim = einsum("h i d, h j d -> h i j", normed_codes, normed_codes) diff --git a/lerobot/common/robot_devices/cameras/configs.py b/lerobot/common/robot_devices/cameras/configs.py deleted file mode 100644 index 013419a9e..000000000 --- a/lerobot/common/robot_devices/cameras/configs.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from dataclasses import dataclass - -import draccus - - -@dataclass -class CameraConfig(draccus.ChoiceRegistry, abc.ABC): - @property - def type(self) -> str: - return self.get_choice_name(self.__class__) - - -@CameraConfig.register_subclass("opencv") -@dataclass -class OpenCVCameraConfig(CameraConfig): - """ - Example of tested options for Intel Real Sense D405: - - ```python - OpenCVCameraConfig(0, 30, 640, 480) - OpenCVCameraConfig(0, 60, 640, 480) - OpenCVCameraConfig(0, 90, 640, 480) - OpenCVCameraConfig(0, 30, 1280, 720) - ``` - """ - - camera_index: int - fps: int | None = None - width: int | None = None - height: int | None = None - color_mode: str = "rgb" - channels: int | None = None - rotation: int | None = None - mock: bool = False - - def __post_init__(self): - if self.color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." - ) - - self.channels = 3 - - if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") - - -@CameraConfig.register_subclass("intelrealsense") -@dataclass -class IntelRealSenseCameraConfig(CameraConfig): - """ - Example of tested options for Intel Real Sense D405: - - ```python - IntelRealSenseCameraConfig(128422271347, 30, 640, 480) - IntelRealSenseCameraConfig(128422271347, 60, 640, 480) - IntelRealSenseCameraConfig(128422271347, 90, 640, 480) - IntelRealSenseCameraConfig(128422271347, 30, 1280, 720) - IntelRealSenseCameraConfig(128422271347, 30, 640, 480, use_depth=True) - IntelRealSenseCameraConfig(128422271347, 30, 640, 480, rotation=90) - ``` - """ - - name: str | None = None - serial_number: int | None = None - fps: int | None = None - width: int | None = None - height: int | None = None - color_mode: str = "rgb" - channels: int | None = None - use_depth: bool = False - force_hardware_reset: bool = True - rotation: int | None = None - mock: bool = False - - def __post_init__(self): - # bool is stronger than is None, since it works with empty strings - if bool(self.name) and bool(self.serial_number): - raise ValueError( - f"One of them must be set: name or serial_number, but {self.name=} and {self.serial_number=} provided." - ) - - if self.color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"`color_mode` is expected to be 'rgb' or 'bgr', but {self.color_mode} is provided." - ) - - self.channels = 3 - - at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None - at_least_one_is_none = self.fps is None or self.width is None or self.height is None - if at_least_one_is_not_none and at_least_one_is_none: - raise ValueError( - "For `fps`, `width` and `height`, either all of them need to be set, or none of them, " - f"but {self.fps=}, {self.width=}, {self.height=} were provided." - ) - - if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py deleted file mode 100644 index 7a21661a8..000000000 --- a/lerobot/common/robot_devices/cameras/intelrealsense.py +++ /dev/null @@ -1,538 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file contains utilities for recording frames from Intel Realsense cameras. -""" - -import argparse -import concurrent.futures -import logging -import math -import shutil -import threading -import time -import traceback -from collections import Counter -from pathlib import Path -from threading import Thread - -import numpy as np -from PIL import Image - -from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig -from lerobot.common.robot_devices.utils import ( - RobotDeviceAlreadyConnectedError, - RobotDeviceNotConnectedError, - busy_wait, -) -from lerobot.common.utils.utils import capture_timestamp_utc - -SERIAL_NUMBER_INDEX = 1 - - -def find_cameras(raise_when_empty=True, mock=False) -> list[dict]: - """ - Find the names and the serial numbers of the Intel RealSense cameras - connected to the computer. - """ - if mock: - import tests.cameras.mock_pyrealsense2 as rs - else: - import pyrealsense2 as rs - - cameras = [] - for device in rs.context().query_devices(): - serial_number = int(device.get_info(rs.camera_info(SERIAL_NUMBER_INDEX))) - name = device.get_info(rs.camera_info.name) - cameras.append( - { - "serial_number": serial_number, - "name": name, - } - ) - - if raise_when_empty and len(cameras) == 0: - raise OSError( - "Not a single camera was detected. Try re-plugging, or re-installing `librealsense` and its python wrapper `pyrealsense2`, or updating the firmware." - ) - - return cameras - - -def save_image(img_array, serial_number, frame_index, images_dir): - try: - img = Image.fromarray(img_array) - path = images_dir / f"camera_{serial_number}_frame_{frame_index:06d}.png" - path.parent.mkdir(parents=True, exist_ok=True) - img.save(str(path), quality=100) - logging.info(f"Saved image: {path}") - except Exception as e: - logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") - - -def save_images_from_cameras( - images_dir: Path, - serial_numbers: list[int] | None = None, - fps=None, - width=None, - height=None, - record_time_s=2, - mock=False, -): - """ - Initializes all the cameras and saves images to the directory. Useful to visually identify the camera - associated to a given serial number. - """ - if serial_numbers is None or len(serial_numbers) == 0: - camera_infos = find_cameras(mock=mock) - serial_numbers = [cam["serial_number"] for cam in camera_infos] - - if mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - print("Connecting cameras") - cameras = [] - for cam_sn in serial_numbers: - print(f"{cam_sn=}") - config = IntelRealSenseCameraConfig( - serial_number=cam_sn, fps=fps, width=width, height=height, mock=mock - ) - camera = IntelRealSenseCamera(config) - camera.connect() - print( - f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.capture_width}, height={camera.capture_height}, color_mode={camera.color_mode})" - ) - cameras.append(camera) - - images_dir = Path(images_dir) - if images_dir.exists(): - shutil.rmtree( - images_dir, - ) - images_dir.mkdir(parents=True, exist_ok=True) - - print(f"Saving images to {images_dir}") - frame_index = 0 - start_time = time.perf_counter() - try: - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - while True: - now = time.perf_counter() - - for camera in cameras: - # If we use async_read when fps is None, the loop will go full speed, and we will end up - # saving the same images from the cameras multiple times until the RAM/disk is full. - image = camera.read() if fps is None else camera.async_read() - if image is None: - print("No Frame") - - bgr_converted_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - - executor.submit( - save_image, - bgr_converted_image, - camera.serial_number, - frame_index, - images_dir, - ) - - if fps is not None: - dt_s = time.perf_counter() - now - busy_wait(1 / fps - dt_s) - - if time.perf_counter() - start_time > record_time_s: - break - - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") - - frame_index += 1 - finally: - print(f"Images have been saved to {images_dir}") - for camera in cameras: - camera.disconnect() - - -class IntelRealSenseCamera: - """ - The IntelRealSenseCamera class is similar to OpenCVCamera class but adds additional features for Intel Real Sense cameras: - - is instantiated with the serial number of the camera - won't randomly change as it can be the case of OpenCVCamera for Linux, - - can also be instantiated with the camera's name — if it's unique — using IntelRealSenseCamera.init_from_name(), - - depth map can be returned. - - To find the camera indices of your cameras, you can run our utility script that will save a few frames for each camera: - ```bash - python lerobot/common/robot_devices/cameras/intelrealsense.py --images-dir outputs/images_from_intelrealsense_cameras - ``` - - When an IntelRealSenseCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode - of the given camera will be used. - - Example of instantiating with a serial number: - ```python - from lerobot.common.robot_devices.cameras.configs import IntelRealSenseCameraConfig - - config = IntelRealSenseCameraConfig(serial_number=128422271347) - camera = IntelRealSenseCamera(config) - camera.connect() - color_image = camera.read() - # when done using the camera, consider disconnecting - camera.disconnect() - ``` - - Example of instantiating with a name if it's unique: - ``` - config = IntelRealSenseCameraConfig(name="Intel RealSense D405") - ``` - - Example of changing default fps, width, height and color_mode: - ```python - config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=30, width=1280, height=720) - config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480) - config = IntelRealSenseCameraConfig(serial_number=128422271347, fps=90, width=640, height=480, color_mode="bgr") - # Note: might error out upon `camera.connect()` if these settings are not compatible with the camera - ``` - - Example of returning depth: - ```python - config = IntelRealSenseCameraConfig(serial_number=128422271347, use_depth=True) - camera = IntelRealSenseCamera(config) - camera.connect() - color_image, depth_map = camera.read() - ``` - """ - - def __init__( - self, - config: IntelRealSenseCameraConfig, - ): - self.config = config - if config.name is not None: - self.serial_number = self.find_serial_number_from_name(config.name) - else: - self.serial_number = config.serial_number - - # Store the raw (capture) resolution from the config. - self.capture_width = config.width - self.capture_height = config.height - - # If rotated by ±90, swap width and height. - if config.rotation in [-90, 90]: - self.width = config.height - self.height = config.width - else: - self.width = config.width - self.height = config.height - - self.fps = config.fps - self.channels = config.channels - self.color_mode = config.color_mode - self.use_depth = config.use_depth - self.force_hardware_reset = config.force_hardware_reset - self.mock = config.mock - - self.camera = None - self.is_connected = False - self.thread = None - self.stop_event = None - self.color_image = None - self.depth_map = None - self.logs = {} - - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - self.rotation = None - if config.rotation == -90: - self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE - elif config.rotation == 90: - self.rotation = cv2.ROTATE_90_CLOCKWISE - elif config.rotation == 180: - self.rotation = cv2.ROTATE_180 - - def find_serial_number_from_name(self, name): - camera_infos = find_cameras() - camera_names = [cam["name"] for cam in camera_infos] - this_name_count = Counter(camera_names)[name] - if this_name_count > 1: - # TODO(aliberts): Test this with multiple identical cameras (Aloha) - raise ValueError( - f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." - ) - - name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} - cam_sn = name_to_serial_dict[name] - - return cam_sn - - def connect(self): - if self.is_connected: - raise RobotDeviceAlreadyConnectedError( - f"IntelRealSenseCamera({self.serial_number}) is already connected." - ) - - if self.mock: - import tests.cameras.mock_pyrealsense2 as rs - else: - import pyrealsense2 as rs - - config = rs.config() - config.enable_device(str(self.serial_number)) - - if self.fps and self.capture_width and self.capture_height: - # TODO(rcadene): can we set rgb8 directly? - config.enable_stream( - rs.stream.color, self.capture_width, self.capture_height, rs.format.rgb8, self.fps - ) - else: - config.enable_stream(rs.stream.color) - - if self.use_depth: - if self.fps and self.capture_width and self.capture_height: - config.enable_stream( - rs.stream.depth, self.capture_width, self.capture_height, rs.format.z16, self.fps - ) - else: - config.enable_stream(rs.stream.depth) - - self.camera = rs.pipeline() - try: - profile = self.camera.start(config) - is_camera_open = True - except RuntimeError: - is_camera_open = False - traceback.print_exc() - - # If the camera doesn't work, display the camera indices corresponding to - # valid cameras. - if not is_camera_open: - # Verify that the provided `serial_number` is valid before printing the traceback - camera_infos = find_cameras() - serial_numbers = [cam["serial_number"] for cam in camera_infos] - if self.serial_number not in serial_numbers: - raise ValueError( - f"`serial_number` is expected to be one of these available cameras {serial_numbers}, but {self.serial_number} is provided instead. " - "To find the serial number you should use, run `python lerobot/common/robot_devices/cameras/intelrealsense.py`." - ) - - raise OSError(f"Can't access IntelRealSenseCamera({self.serial_number}).") - - color_stream = profile.get_stream(rs.stream.color) - color_profile = color_stream.as_video_stream_profile() - actual_fps = color_profile.fps() - actual_width = color_profile.width() - actual_height = color_profile.height() - - # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): - # Using `OSError` since it's a broad that encompasses issues related to device communication - raise OSError( - f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}." - ) - if self.capture_width is not None and self.capture_width != actual_width: - raise OSError( - f"Can't set {self.capture_width=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_width}." - ) - if self.capture_height is not None and self.capture_height != actual_height: - raise OSError( - f"Can't set {self.capture_height=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_height}." - ) - - self.fps = round(actual_fps) - self.capture_width = round(actual_width) - self.capture_height = round(actual_height) - - self.is_connected = True - - def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]: - """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3) - of type `np.uint8`, contrarily to the pytorch format which is float channel first. - - When `use_depth=True`, returns a tuple `(color_image, depth_map)` with a depth map in the format - height x width (e.g. 480 x 640) of type np.uint16. - - Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. - If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. - """ - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." - ) - - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - start_time = time.perf_counter() - - frame = self.camera.wait_for_frames(timeout_ms=5000) - - color_frame = frame.get_color_frame() - - if not color_frame: - raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).") - - color_image = np.asanyarray(color_frame.get_data()) - - requested_color_mode = self.color_mode if temporary_color is None else temporary_color - if requested_color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." - ) - - # IntelRealSense uses RGB format as default (red, green, blue). - if requested_color_mode == "bgr": - color_image = cv2.cvtColor(color_image, cv2.COLOR_RGB2BGR) - - h, w, _ = color_image.shape - if h != self.capture_height or w != self.capture_width: - raise OSError( - f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." - ) - - if self.rotation is not None: - color_image = cv2.rotate(color_image, self.rotation) - - # log the number of seconds it took to read the image - self.logs["delta_timestamp_s"] = time.perf_counter() - start_time - - # log the utc time at which the image was received - self.logs["timestamp_utc"] = capture_timestamp_utc() - - if self.use_depth: - depth_frame = frame.get_depth_frame() - if not depth_frame: - raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).") - - depth_map = np.asanyarray(depth_frame.get_data()) - - h, w = depth_map.shape - if h != self.capture_height or w != self.capture_width: - raise OSError( - f"Can't capture depth map with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." - ) - - if self.rotation is not None: - depth_map = cv2.rotate(depth_map, self.rotation) - - return color_image, depth_map - else: - return color_image - - def read_loop(self): - while not self.stop_event.is_set(): - if self.use_depth: - self.color_image, self.depth_map = self.read() - else: - self.color_image = self.read() - - def async_read(self): - """Access the latest color image""" - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." - ) - - if self.thread is None: - self.stop_event = threading.Event() - self.thread = Thread(target=self.read_loop, args=()) - self.thread.daemon = True - self.thread.start() - - num_tries = 0 - while self.color_image is None: - # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here - num_tries += 1 - time.sleep(1 / self.fps) - if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): - raise Exception( - "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." - ) - - if self.use_depth: - return self.color_image, self.depth_map - else: - return self.color_image - - def disconnect(self): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"IntelRealSenseCamera({self.serial_number}) is not connected. Try running `camera.connect()` first." - ) - - if self.thread is not None and self.thread.is_alive(): - # wait for the thread to finish - self.stop_event.set() - self.thread.join() - self.thread = None - self.stop_event = None - - self.camera.stop() - self.camera = None - - self.is_connected = False - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Save a few frames using `IntelRealSenseCamera` for all cameras connected to the computer, or a selected subset." - ) - parser.add_argument( - "--serial-numbers", - type=int, - nargs="*", - default=None, - help="List of serial numbers used to instantiate the `IntelRealSenseCamera`. If not provided, find and use all available camera indices.", - ) - parser.add_argument( - "--fps", - type=int, - default=30, - help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.", - ) - parser.add_argument( - "--width", - type=str, - default=640, - help="Set the width for all cameras. If not provided, use the default width of each camera.", - ) - parser.add_argument( - "--height", - type=str, - default=480, - help="Set the height for all cameras. If not provided, use the default height of each camera.", - ) - parser.add_argument( - "--images-dir", - type=Path, - default="outputs/images_from_intelrealsense_cameras", - help="Set directory to save a few frames for each camera.", - ) - parser.add_argument( - "--record-time-s", - type=float, - default=2.0, - help="Set the number of seconds used to record the frames. By default, 2 seconds.", - ) - args = parser.parse_args() - save_images_from_cameras(**vars(args)) diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py deleted file mode 100644 index f279f3158..000000000 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ /dev/null @@ -1,518 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file contains utilities for recording frames from cameras. For more info look at `OpenCVCamera` docstring. -""" - -import argparse -import concurrent.futures -import math -import platform -import shutil -import threading -import time -from pathlib import Path -from threading import Thread - -import numpy as np -from PIL import Image - -from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig -from lerobot.common.robot_devices.utils import ( - RobotDeviceAlreadyConnectedError, - RobotDeviceNotConnectedError, - busy_wait, -) -from lerobot.common.utils.utils import capture_timestamp_utc - -# The maximum opencv device index depends on your operating system. For instance, -# if you have 3 cameras, they should be associated to index 0, 1, and 2. This is the case -# on MacOS. However, on Ubuntu, the indices are different like 6, 16, 23. -# When you change the USB port or reboot the computer, the operating system might -# treat the same cameras as new devices. Thus we select a higher bound to search indices. -MAX_OPENCV_INDEX = 60 - - -def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: - cameras = [] - if platform.system() == "Linux": - print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") - possible_ports = [str(port) for port in Path("/dev").glob("video*")] - ports = _find_cameras(possible_ports, mock=mock) - for port in ports: - cameras.append( - { - "port": port, - "index": int(port.removeprefix("/dev/video")), - } - ) - else: - print( - "Mac or Windows detected. Finding available camera indices through " - f"scanning all indices from 0 to {MAX_OPENCV_INDEX}" - ) - possible_indices = range(max_index_search_range) - indices = _find_cameras(possible_indices, mock=mock) - for index in indices: - cameras.append( - { - "port": None, - "index": index, - } - ) - - return cameras - - -def _find_cameras( - possible_camera_ids: list[int | str], raise_when_empty=False, mock=False -) -> list[int | str]: - if mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - camera_ids = [] - for camera_idx in possible_camera_ids: - camera = cv2.VideoCapture(camera_idx) - is_open = camera.isOpened() - camera.release() - - if is_open: - print(f"Camera found at index {camera_idx}") - camera_ids.append(camera_idx) - - if raise_when_empty and len(camera_ids) == 0: - raise OSError( - "Not a single camera was detected. Try re-plugging, or re-installing `opencv2`, " - "or your camera driver, or make sure your camera is compatible with opencv2." - ) - - return camera_ids - - -def is_valid_unix_path(path: str) -> bool: - """Note: if 'path' points to a symlink, this will return True only if the target exists""" - p = Path(path) - return p.is_absolute() and p.exists() - - -def get_camera_index_from_unix_port(port: Path) -> int: - return int(str(port.resolve()).removeprefix("/dev/video")) - - -def save_image(img_array, camera_index, frame_index, images_dir): - img = Image.fromarray(img_array) - path = images_dir / f"camera_{camera_index:02d}_frame_{frame_index:06d}.png" - path.parent.mkdir(parents=True, exist_ok=True) - img.save(str(path), quality=100) - - -def save_images_from_cameras( - images_dir: Path, - camera_ids: list | None = None, - fps=None, - width=None, - height=None, - record_time_s=2, - mock=False, -): - """ - Initializes all the cameras and saves images to the directory. Useful to visually identify the camera - associated to a given camera index. - """ - if camera_ids is None or len(camera_ids) == 0: - camera_infos = find_cameras(mock=mock) - camera_ids = [cam["index"] for cam in camera_infos] - - print("Connecting cameras") - cameras = [] - for cam_idx in camera_ids: - config = OpenCVCameraConfig(camera_index=cam_idx, fps=fps, width=width, height=height, mock=mock) - camera = OpenCVCamera(config) - camera.connect() - print( - f"OpenCVCamera({camera.camera_index}, fps={camera.fps}, width={camera.capture_width}, " - f"height={camera.capture_height}, color_mode={camera.color_mode})" - ) - cameras.append(camera) - - images_dir = Path(images_dir) - if images_dir.exists(): - shutil.rmtree( - images_dir, - ) - images_dir.mkdir(parents=True, exist_ok=True) - - print(f"Saving images to {images_dir}") - frame_index = 0 - start_time = time.perf_counter() - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - while True: - now = time.perf_counter() - - for camera in cameras: - # If we use async_read when fps is None, the loop will go full speed, and we will endup - # saving the same images from the cameras multiple times until the RAM/disk is full. - image = camera.read() if fps is None else camera.async_read() - - executor.submit( - save_image, - image, - camera.camera_index, - frame_index, - images_dir, - ) - - if fps is not None: - dt_s = time.perf_counter() - now - busy_wait(1 / fps - dt_s) - - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") - - if time.perf_counter() - start_time > record_time_s: - break - - frame_index += 1 - - print(f"Images have been saved to {images_dir}") - - -class OpenCVCamera: - """ - The OpenCVCamera class allows to efficiently record images from cameras. It relies on opencv2 to communicate - with the cameras. Most cameras are compatible. For more info, see the [Video I/O with OpenCV Overview](https://docs.opencv.org/4.x/d0/da7/videoio_overview.html). - - An OpenCVCamera instance requires a camera index (e.g. `OpenCVCamera(camera_index=0)`). When you only have one camera - like a webcam of a laptop, the camera index is expected to be 0, but it might also be very different, and the camera index - might change if you reboot your computer or re-plug your camera. This behavior depends on your operation system. - - To find the camera indices of your cameras, you can run our utility script that will be save a few frames for each camera: - ```bash - python lerobot/common/robot_devices/cameras/opencv.py --images-dir outputs/images_from_opencv_cameras - ``` - - When an OpenCVCamera is instantiated, if no specific config is provided, the default fps, width, height and color_mode - of the given camera will be used. - - Example of usage: - ```python - from lerobot.common.robot_devices.cameras.configs import OpenCVCameraConfig - - config = OpenCVCameraConfig(camera_index=0) - camera = OpenCVCamera(config) - camera.connect() - color_image = camera.read() - # when done using the camera, consider disconnecting - camera.disconnect() - ``` - - Example of changing default fps, width, height and color_mode: - ```python - config = OpenCVCameraConfig(camera_index=0, fps=30, width=1280, height=720) - config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480) - config = OpenCVCameraConfig(camera_index=0, fps=90, width=640, height=480, color_mode="bgr") - # Note: might error out open `camera.connect()` if these settings are not compatible with the camera - ``` - """ - - def __init__(self, config: OpenCVCameraConfig): - self.config = config - self.camera_index = config.camera_index - self.port = None - - # Linux uses ports for connecting to cameras - if platform.system() == "Linux": - if isinstance(self.camera_index, int): - self.port = Path(f"/dev/video{self.camera_index}") - elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): - self.port = Path(self.camera_index) - # Retrieve the camera index from a potentially symlinked path - self.camera_index = get_camera_index_from_unix_port(self.port) - else: - raise ValueError(f"Please check the provided camera_index: {self.camera_index}") - - # Store the raw (capture) resolution from the config. - self.capture_width = config.width - self.capture_height = config.height - - # If rotated by ±90, swap width and height. - if config.rotation in [-90, 90]: - self.width = config.height - self.height = config.width - else: - self.width = config.width - self.height = config.height - - self.fps = config.fps - self.channels = config.channels - self.color_mode = config.color_mode - self.mock = config.mock - - self.camera = None - self.is_connected = False - self.thread = None - self.stop_event = None - self.color_image = None - self.logs = {} - - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - self.rotation = None - if config.rotation == -90: - self.rotation = cv2.ROTATE_90_COUNTERCLOCKWISE - elif config.rotation == 90: - self.rotation = cv2.ROTATE_90_CLOCKWISE - elif config.rotation == 180: - self.rotation = cv2.ROTATE_180 - - def connect(self): - if self.is_connected: - raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") - - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - # Use 1 thread to avoid blocking the main thread. Especially useful during data collection - # when other threads are used to save the images. - cv2.setNumThreads(1) - - backend = ( - cv2.CAP_V4L2 - if platform.system() == "Linux" - else cv2.CAP_DSHOW - if platform.system() == "Windows" - else cv2.CAP_AVFOUNDATION - if platform.system() == "Darwin" - else cv2.CAP_ANY - ) - - camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index - # First create a temporary camera trying to access `camera_index`, - # and verify it is a valid camera by calling `isOpened`. - tmp_camera = cv2.VideoCapture(camera_idx, backend) - is_camera_open = tmp_camera.isOpened() - # Release camera to make it accessible for `find_camera_indices` - tmp_camera.release() - del tmp_camera - - # If the camera doesn't work, display the camera indices corresponding to - # valid cameras. - if not is_camera_open: - # Verify that the provided `camera_index` is valid before printing the traceback - cameras_info = find_cameras() - available_cam_ids = [cam["index"] for cam in cameras_info] - if self.camera_index not in available_cam_ids: - raise ValueError( - f"`camera_index` is expected to be one of these available cameras {available_cam_ids}, but {self.camera_index} is provided instead. " - "To find the camera index you should use, run `python lerobot/common/robot_devices/cameras/opencv.py`." - ) - - raise OSError(f"Can't access OpenCVCamera({camera_idx}).") - - # Secondly, create the camera that will be used downstream. - # Note: For some unknown reason, calling `isOpened` blocks the camera which then - # needs to be re-created. - self.camera = cv2.VideoCapture(camera_idx, backend) - - if self.fps is not None: - self.camera.set(cv2.CAP_PROP_FPS, self.fps) - if self.capture_width is not None: - self.camera.set(cv2.CAP_PROP_FRAME_WIDTH, self.capture_width) - if self.capture_height is not None: - self.camera.set(cv2.CAP_PROP_FRAME_HEIGHT, self.capture_height) - - actual_fps = self.camera.get(cv2.CAP_PROP_FPS) - actual_width = self.camera.get(cv2.CAP_PROP_FRAME_WIDTH) - actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT) - - # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): - # Using `OSError` since it's a broad that encompasses issues related to device communication - raise OSError( - f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}." - ) - if self.capture_width is not None and not math.isclose( - self.capture_width, actual_width, rel_tol=1e-3 - ): - raise OSError( - f"Can't set {self.capture_width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}." - ) - if self.capture_height is not None and not math.isclose( - self.capture_height, actual_height, rel_tol=1e-3 - ): - raise OSError( - f"Can't set {self.capture_height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}." - ) - - self.fps = round(actual_fps) - self.capture_width = round(actual_width) - self.capture_height = round(actual_height) - self.is_connected = True - - def read(self, temporary_color_mode: str | None = None) -> np.ndarray: - """Read a frame from the camera returned in the format (height, width, channels) - (e.g. 480 x 640 x 3), contrarily to the pytorch format which is channel first. - - Note: Reading a frame is done every `camera.fps` times per second, and it is blocking. - If you are reading data from other sensors, we advise to use `camera.async_read()` which is non blocking version of `camera.read()`. - """ - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." - ) - - start_time = time.perf_counter() - - ret, color_image = self.camera.read() - - if not ret: - raise OSError(f"Can't capture color image from camera {self.camera_index}.") - - requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode - - if requested_color_mode not in ["rgb", "bgr"]: - raise ValueError( - f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." - ) - - # OpenCV uses BGR format as default (blue, green, red) for all operations, including displaying images. - # However, Deep Learning framework such as LeRobot uses RGB format as default to train neural networks, - # so we convert the image color from BGR to RGB. - if requested_color_mode == "rgb": - if self.mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - color_image = cv2.cvtColor(color_image, cv2.COLOR_BGR2RGB) - - h, w, _ = color_image.shape - if h != self.capture_height or w != self.capture_width: - raise OSError( - f"Can't capture color image with expected height and width ({self.height} x {self.width}). ({h} x {w}) returned instead." - ) - - if self.rotation is not None: - color_image = cv2.rotate(color_image, self.rotation) - - # log the number of seconds it took to read the image - self.logs["delta_timestamp_s"] = time.perf_counter() - start_time - - # log the utc time at which the image was received - self.logs["timestamp_utc"] = capture_timestamp_utc() - - self.color_image = color_image - - return color_image - - def read_loop(self): - while not self.stop_event.is_set(): - try: - self.color_image = self.read() - except Exception as e: - print(f"Error reading in thread: {e}") - - def async_read(self): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." - ) - - if self.thread is None: - self.stop_event = threading.Event() - self.thread = Thread(target=self.read_loop, args=()) - self.thread.daemon = True - self.thread.start() - - num_tries = 0 - while True: - if self.color_image is not None: - return self.color_image - - time.sleep(1 / self.fps) - num_tries += 1 - if num_tries > self.fps * 2: - raise TimeoutError("Timed out waiting for async_read() to start.") - - def disconnect(self): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"OpenCVCamera({self.camera_index}) is not connected. Try running `camera.connect()` first." - ) - - if self.thread is not None: - self.stop_event.set() - self.thread.join() # wait for the thread to finish - self.thread = None - self.stop_event = None - - self.camera.release() - self.camera = None - self.is_connected = False - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - description="Save a few frames using `OpenCVCamera` for all cameras connected to the computer, or a selected subset." - ) - parser.add_argument( - "--camera-ids", - type=int, - nargs="*", - default=None, - help="List of camera indices used to instantiate the `OpenCVCamera`. If not provided, find and use all available camera indices.", - ) - parser.add_argument( - "--fps", - type=int, - default=None, - help="Set the number of frames recorded per seconds for all cameras. If not provided, use the default fps of each camera.", - ) - parser.add_argument( - "--width", - type=str, - default=None, - help="Set the width for all cameras. If not provided, use the default width of each camera.", - ) - parser.add_argument( - "--height", - type=str, - default=None, - help="Set the height for all cameras. If not provided, use the default height of each camera.", - ) - parser.add_argument( - "--images-dir", - type=Path, - default="outputs/images_from_opencv_cameras", - help="Set directory to save a few frames for each camera.", - ) - parser.add_argument( - "--record-time-s", - type=float, - default=4.0, - help="Set the number of seconds used to record the frames. By default, 2 seconds.", - ) - args = parser.parse_args() - save_images_from_cameras(**vars(args)) diff --git a/lerobot/common/robot_devices/cameras/utils.py b/lerobot/common/robot_devices/cameras/utils.py deleted file mode 100644 index c64316467..000000000 --- a/lerobot/common/robot_devices/cameras/utils.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Protocol - -import numpy as np - -from lerobot.common.robot_devices.cameras.configs import ( - CameraConfig, - IntelRealSenseCameraConfig, - OpenCVCameraConfig, -) - - -# Defines a camera type -class Camera(Protocol): - def connect(self): ... - def read(self, temporary_color: str | None = None) -> np.ndarray: ... - def async_read(self) -> np.ndarray: ... - def disconnect(self): ... - - -def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> list[Camera]: - cameras = {} - - for key, cfg in camera_configs.items(): - if cfg.type == "opencv": - from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera - - cameras[key] = OpenCVCamera(cfg) - - elif cfg.type == "intelrealsense": - from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera - - cameras[key] = IntelRealSenseCamera(cfg) - else: - raise ValueError(f"The camera type '{cfg.type}' is not valid.") - - return cameras - - -def make_camera(camera_type, **kwargs) -> Camera: - if camera_type == "opencv": - from lerobot.common.robot_devices.cameras.opencv import OpenCVCamera - - config = OpenCVCameraConfig(**kwargs) - return OpenCVCamera(config) - - elif camera_type == "intelrealsense": - from lerobot.common.robot_devices.cameras.intelrealsense import IntelRealSenseCamera - - config = IntelRealSenseCameraConfig(**kwargs) - return IntelRealSenseCamera(config) - - else: - raise ValueError(f"The camera type '{camera_type}' is not valid.") diff --git a/lerobot/common/robot_devices/control_configs.py b/lerobot/common/robot_devices/control_configs.py deleted file mode 100644 index cb558c716..000000000 --- a/lerobot/common/robot_devices/control_configs.py +++ /dev/null @@ -1,134 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from pathlib import Path - -import draccus - -from lerobot.common.robot_devices.robots.configs import RobotConfig -from lerobot.configs import parser -from lerobot.configs.policies import PreTrainedConfig - - -@dataclass -class ControlConfig(draccus.ChoiceRegistry): - pass - - -@ControlConfig.register_subclass("calibrate") -@dataclass -class CalibrateControlConfig(ControlConfig): - # List of arms to calibrate (e.g. `--arms='["left_follower","right_follower"]' left_leader`) - arms: list[str] | None = None - - -@ControlConfig.register_subclass("teleoperate") -@dataclass -class TeleoperateControlConfig(ControlConfig): - # Limit the maximum frames per second. By default, no limit. - fps: int | None = None - teleop_time_s: float | None = None - # Display all cameras on screen - display_data: bool = False - - -@ControlConfig.register_subclass("record") -@dataclass -class RecordControlConfig(ControlConfig): - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") - single_task: str - # Root directory where the dataset will be stored (e.g. 'dataset/path'). - root: str | Path | None = None - policy: PreTrainedConfig | None = None - # Limit the frames per second. By default, uses the policy fps. - fps: int | None = None - # Number of seconds before starting data collection. It allows the robot devices to warmup and synchronize. - warmup_time_s: int | float = 10 - # Number of seconds for data recording for each episode. - episode_time_s: int | float = 60 - # Number of seconds for resetting the environment after each episode. - reset_time_s: int | float = 60 - # Number of episodes to record. - num_episodes: int = 50 - # Encode frames in the dataset into video - video: bool = True - # Upload dataset to Hugging Face hub. - push_to_hub: bool = True - # Upload on private repository on the Hugging Face hub. - private: bool = False - # Add tags to your dataset on the hub. - tags: list[str] | None = None - # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; - # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes - # and threads depends on your system. We recommend 4 threads per camera with 0 processes. - # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. - num_image_writer_processes: int = 0 - # Number of threads writing the frames as png images on disk, per camera. - # Too many threads might cause unstable teleoperation fps due to main thread being blocked. - # Not enough threads might cause low camera fps. - num_image_writer_threads_per_camera: int = 4 - # Display all cameras on screen - display_data: bool = False - # Use vocal synthesis to read events. - play_sounds: bool = True - # Resume recording on an existing dataset. - resume: bool = False - - def __post_init__(self): - # HACK: We parse again the cli args here to get the pretrained path if there was one. - policy_path = parser.get_path_arg("control.policy") - if policy_path: - cli_overrides = parser.get_cli_overrides("control.policy") - self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) - self.policy.pretrained_path = policy_path - - -@ControlConfig.register_subclass("replay") -@dataclass -class ReplayControlConfig(ControlConfig): - # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). - repo_id: str - # Index of the episode to replay. - episode: int - # Root directory where the dataset will be stored (e.g. 'dataset/path'). - root: str | Path | None = None - # Limit the frames per second. By default, uses the dataset fps. - fps: int | None = None - # Use vocal synthesis to read events. - play_sounds: bool = True - - -@ControlConfig.register_subclass("remote_robot") -@dataclass -class RemoteRobotConfig(ControlConfig): - log_interval: int = 100 - # Display all cameras on screen - display_data: bool = False - # Rerun configuration for remote robot (https://ref.rerun.io/docs/python/0.22.1/common/initialization_functions/#rerun.connect_tcp) - viewer_ip: str | None = None - viewer_port: str | None = None - - -@dataclass -class ControlPipelineConfig: - robot: RobotConfig - control: ControlConfig - - @classmethod - def __get_path_fields__(cls) -> list[str]: - """This enables the parser to load config from the policy using `--policy.path=local/dir`""" - return ["control.policy"] diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py deleted file mode 100644 index 4e42a9896..000000000 --- a/lerobot/common/robot_devices/control_utils.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -######################################################################################## -# Utilities -######################################################################################## - - -import logging -import time -import traceback -from contextlib import nullcontext -from copy import copy -from functools import cache - -import rerun as rr -import torch -from deepdiff import DeepDiff -from termcolor import colored - -from lerobot.common.datasets.image_writer import safe_stop_image_writer -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.datasets.utils import get_features_from_robot -from lerobot.common.policies.pretrained import PreTrainedPolicy -from lerobot.common.robot_devices.robots.utils import Robot -from lerobot.common.robot_devices.utils import busy_wait -from lerobot.common.utils.utils import get_safe_torch_device, has_method - - -def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): - log_items = [] - if episode_index is not None: - log_items.append(f"ep:{episode_index}") - if frame_index is not None: - log_items.append(f"frame:{frame_index}") - - def log_dt(shortname, dt_val_s): - nonlocal log_items, fps - info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" - if fps is not None: - actual_fps = 1 / dt_val_s - if actual_fps < fps - 1: - info_str = colored(info_str, "yellow") - log_items.append(info_str) - - # total step time displayed in milliseconds and its frequency - log_dt("dt", dt_s) - - # TODO(aliberts): move robot-specific logs logic in robot.print_logs() - if not robot.robot_type.startswith("stretch"): - for name in robot.leader_arms: - key = f"read_leader_{name}_pos_dt_s" - if key in robot.logs: - log_dt("dtRlead", robot.logs[key]) - - for name in robot.follower_arms: - key = f"write_follower_{name}_goal_pos_dt_s" - if key in robot.logs: - log_dt("dtWfoll", robot.logs[key]) - - key = f"read_follower_{name}_pos_dt_s" - if key in robot.logs: - log_dt("dtRfoll", robot.logs[key]) - - for name in robot.cameras: - key = f"read_camera_{name}_dt_s" - if key in robot.logs: - log_dt(f"dtR{name}", robot.logs[key]) - - info_str = " ".join(log_items) - logging.info(info_str) - - -@cache -def is_headless(): - """Detects if python is running without a monitor.""" - try: - import pynput # noqa - - return False - except Exception: - print( - "Error trying to import pynput. Switching to headless mode. " - "As a result, the video stream from the cameras won't be shown, " - "and you won't be able to change the control flow with keyboards. " - "For more info, see traceback below.\n" - ) - traceback.print_exc() - print() - return True - - -def predict_action(observation, policy, device, use_amp): - observation = copy(observation) - with ( - torch.inference_mode(), - torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), - ): - # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension - for name in observation: - if "image" in name: - observation[name] = observation[name].type(torch.float32) / 255 - observation[name] = observation[name].permute(2, 0, 1).contiguous() - observation[name] = observation[name].unsqueeze(0) - observation[name] = observation[name].to(device) - - # Compute the next action with the policy - # based on the current observation - action = policy.select_action(observation) - - # Remove batch dimension - action = action.squeeze(0) - - # Move to cpu, if not already the case - action = action.to("cpu") - - return action - - -def init_keyboard_listener(): - # Allow to exit early while recording an episode or resetting the environment, - # by tapping the right arrow key '->'. This might require a sudo permission - # to allow your terminal to monitor keyboard events. - events = {} - events["exit_early"] = False - events["rerecord_episode"] = False - events["stop_recording"] = False - - if is_headless(): - logging.warning( - "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." - ) - listener = None - return listener, events - - # Only import pynput if not in a headless environment - from pynput import keyboard - - def on_press(key): - try: - if key == keyboard.Key.right: - print("Right arrow key pressed. Exiting loop...") - events["exit_early"] = True - elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") - events["rerecord_episode"] = True - events["exit_early"] = True - elif key == keyboard.Key.esc: - print("Escape key pressed. Stopping data recording...") - events["stop_recording"] = True - events["exit_early"] = True - except Exception as e: - print(f"Error handling key press: {e}") - - listener = keyboard.Listener(on_press=on_press) - listener.start() - - return listener, events - - -def warmup_record( - robot, - events, - enable_teleoperation, - warmup_time_s, - display_data, - fps, -): - control_loop( - robot=robot, - control_time_s=warmup_time_s, - display_data=display_data, - events=events, - fps=fps, - teleoperate=enable_teleoperation, - ) - - -def record_episode( - robot, - dataset, - events, - episode_time_s, - display_data, - policy, - fps, - single_task, -): - control_loop( - robot=robot, - control_time_s=episode_time_s, - display_data=display_data, - dataset=dataset, - events=events, - policy=policy, - fps=fps, - teleoperate=policy is None, - single_task=single_task, - ) - - -@safe_stop_image_writer -def control_loop( - robot, - control_time_s=None, - teleoperate=False, - display_data=False, - dataset: LeRobotDataset | None = None, - events=None, - policy: PreTrainedPolicy = None, - fps: int | None = None, - single_task: str | None = None, -): - # TODO(rcadene): Add option to record logs - if not robot.is_connected: - robot.connect() - - if events is None: - events = {"exit_early": False} - - if control_time_s is None: - control_time_s = float("inf") - - if teleoperate and policy is not None: - raise ValueError("When `teleoperate` is True, `policy` should be None.") - - if dataset is not None and single_task is None: - raise ValueError("You need to provide a task as argument in `single_task`.") - - if dataset is not None and fps is not None and dataset.fps != fps: - raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).") - - timestamp = 0 - start_episode_t = time.perf_counter() - while timestamp < control_time_s: - start_loop_t = time.perf_counter() - - if teleoperate: - observation, action = robot.teleop_step(record_data=True) - else: - observation = robot.capture_observation() - - if policy is not None: - pred_action = predict_action( - observation, policy, get_safe_torch_device(policy.config.device), policy.config.use_amp - ) - # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. - action = robot.send_action(pred_action) - action = {"action": action} - - if dataset is not None: - frame = {**observation, **action, "task": single_task} - dataset.add_frame(frame) - - # TODO(Steven): This should be more general (for RemoteRobot instead of checking the name, but anyways it will change soon) - if (display_data and not is_headless()) or (display_data and robot.robot_type.startswith("lekiwi")): - for k, v in action.items(): - for i, vv in enumerate(v): - rr.log(f"sent_{k}_{i}", rr.Scalar(vv.numpy())) - - image_keys = [key for key in observation if "image" in key] - for key in image_keys: - rr.log(key, rr.Image(observation[key].numpy()), static=True) - - if fps is not None: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) - - timestamp = time.perf_counter() - start_episode_t - if events["exit_early"]: - events["exit_early"] = False - break - - -def reset_environment(robot, events, reset_time_s, fps): - # TODO(rcadene): refactor warmup_record and reset_environment - if has_method(robot, "teleop_safety_stop"): - robot.teleop_safety_stop() - - control_loop( - robot=robot, - control_time_s=reset_time_s, - events=events, - fps=fps, - teleoperate=True, - ) - - -def stop_recording(robot, listener, display_data): - robot.disconnect() - - if not is_headless() and listener is not None: - listener.stop() - - -def sanity_check_dataset_name(repo_id, policy_cfg): - _, dataset_name = repo_id.split("/") - # either repo_id doesnt start with "eval_" and there is no policy - # or repo_id starts with "eval_" and there is a policy - - # Check if dataset_name starts with "eval_" but policy is missing - if dataset_name.startswith("eval_") and policy_cfg is None: - raise ValueError( - f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." - ) - - # Check if dataset_name does not start with "eval_" but policy is provided - if not dataset_name.startswith("eval_") and policy_cfg is not None: - raise ValueError( - f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})." - ) - - -def sanity_check_dataset_robot_compatibility( - dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool -) -> None: - fields = [ - ("robot_type", dataset.meta.robot_type, robot.robot_type), - ("fps", dataset.fps, fps), - ("features", dataset.features, get_features_from_robot(robot, use_videos)), - ] - - mismatches = [] - for field, dataset_value, present_value in fields: - diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) - if diff: - mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") - - if mismatches: - raise ValueError( - "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) - ) diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py deleted file mode 100644 index 6096ceb5d..000000000 --- a/lerobot/common/robot_devices/motors/dynamixel.py +++ /dev/null @@ -1,873 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import enum -import logging -import math -import time -import traceback -from copy import deepcopy - -import numpy as np -import tqdm - -from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError -from lerobot.common.utils.utils import capture_timestamp_utc - -PROTOCOL_VERSION = 2.0 -BAUDRATE = 1_000_000 -TIMEOUT_MS = 1000 - -MAX_ID_RANGE = 252 - -# The following bounds define the lower and upper joints range (after calibration). -# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees -# which corresponds to a half rotation on the left and half rotation on the right. -# Some joints might require higher range, so we allow up to [-270, 270] degrees until -# an error is raised. -LOWER_BOUND_DEGREE = -270 -UPPER_BOUND_DEGREE = 270 -# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper), -# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully -# closed, and 100% is fully open. To account for slight calibration issue, we allow up to -# [-10, 110] until an error is raised. -LOWER_BOUND_LINEAR = -10 -UPPER_BOUND_LINEAR = 110 - -HALF_TURN_DEGREE = 180 - -# https://emanual.robotis.com/docs/en/dxl/x/xl330-m077 -# https://emanual.robotis.com/docs/en/dxl/x/xl330-m288 -# https://emanual.robotis.com/docs/en/dxl/x/xl430-w250 -# https://emanual.robotis.com/docs/en/dxl/x/xm430-w350 -# https://emanual.robotis.com/docs/en/dxl/x/xm540-w270 -# https://emanual.robotis.com/docs/en/dxl/x/xc430-w150 - -# data_name: (address, size_byte) -X_SERIES_CONTROL_TABLE = { - "Model_Number": (0, 2), - "Model_Information": (2, 4), - "Firmware_Version": (6, 1), - "ID": (7, 1), - "Baud_Rate": (8, 1), - "Return_Delay_Time": (9, 1), - "Drive_Mode": (10, 1), - "Operating_Mode": (11, 1), - "Secondary_ID": (12, 1), - "Protocol_Type": (13, 1), - "Homing_Offset": (20, 4), - "Moving_Threshold": (24, 4), - "Temperature_Limit": (31, 1), - "Max_Voltage_Limit": (32, 2), - "Min_Voltage_Limit": (34, 2), - "PWM_Limit": (36, 2), - "Current_Limit": (38, 2), - "Acceleration_Limit": (40, 4), - "Velocity_Limit": (44, 4), - "Max_Position_Limit": (48, 4), - "Min_Position_Limit": (52, 4), - "Shutdown": (63, 1), - "Torque_Enable": (64, 1), - "LED": (65, 1), - "Status_Return_Level": (68, 1), - "Registered_Instruction": (69, 1), - "Hardware_Error_Status": (70, 1), - "Velocity_I_Gain": (76, 2), - "Velocity_P_Gain": (78, 2), - "Position_D_Gain": (80, 2), - "Position_I_Gain": (82, 2), - "Position_P_Gain": (84, 2), - "Feedforward_2nd_Gain": (88, 2), - "Feedforward_1st_Gain": (90, 2), - "Bus_Watchdog": (98, 1), - "Goal_PWM": (100, 2), - "Goal_Current": (102, 2), - "Goal_Velocity": (104, 4), - "Profile_Acceleration": (108, 4), - "Profile_Velocity": (112, 4), - "Goal_Position": (116, 4), - "Realtime_Tick": (120, 2), - "Moving": (122, 1), - "Moving_Status": (123, 1), - "Present_PWM": (124, 2), - "Present_Current": (126, 2), - "Present_Velocity": (128, 4), - "Present_Position": (132, 4), - "Velocity_Trajectory": (136, 4), - "Position_Trajectory": (140, 4), - "Present_Input_Voltage": (144, 2), - "Present_Temperature": (146, 1), -} - -X_SERIES_BAUDRATE_TABLE = { - 0: 9_600, - 1: 57_600, - 2: 115_200, - 3: 1_000_000, - 4: 2_000_000, - 5: 3_000_000, - 6: 4_000_000, -} - -CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"] -CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"] - -MODEL_CONTROL_TABLE = { - "x_series": X_SERIES_CONTROL_TABLE, - "xl330-m077": X_SERIES_CONTROL_TABLE, - "xl330-m288": X_SERIES_CONTROL_TABLE, - "xl430-w250": X_SERIES_CONTROL_TABLE, - "xm430-w350": X_SERIES_CONTROL_TABLE, - "xm540-w270": X_SERIES_CONTROL_TABLE, - "xc430-w150": X_SERIES_CONTROL_TABLE, -} - -MODEL_RESOLUTION = { - "x_series": 4096, - "xl330-m077": 4096, - "xl330-m288": 4096, - "xl430-w250": 4096, - "xm430-w350": 4096, - "xm540-w270": 4096, - "xc430-w150": 4096, -} - -MODEL_BAUDRATE_TABLE = { - "x_series": X_SERIES_BAUDRATE_TABLE, - "xl330-m077": X_SERIES_BAUDRATE_TABLE, - "xl330-m288": X_SERIES_BAUDRATE_TABLE, - "xl430-w250": X_SERIES_BAUDRATE_TABLE, - "xm430-w350": X_SERIES_BAUDRATE_TABLE, - "xm540-w270": X_SERIES_BAUDRATE_TABLE, - "xc430-w150": X_SERIES_BAUDRATE_TABLE, -} - -NUM_READ_RETRY = 10 -NUM_WRITE_RETRY = 10 - - -def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: - """This function converts the degree range to the step range for indicating motors rotation. - It assumes a motor achieves a full rotation by going from -180 degree position to +180. - The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. - """ - resolutions = [MODEL_RESOLUTION[model] for model in models] - steps = degrees / 180 * np.array(resolutions) / 2 - steps = steps.astype(int) - return steps - - -def convert_to_bytes(value, bytes, mock=False): - if mock: - return value - - import dynamixel_sdk as dxl - - # Note: No need to convert back into unsigned int, since this byte preprocessing - # already handles it for us. - if bytes == 1: - data = [ - dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), - ] - elif bytes == 2: - data = [ - dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), - ] - elif bytes == 4: - data = [ - dxl.DXL_LOBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_LOWORD(value)), - dxl.DXL_LOBYTE(dxl.DXL_HIWORD(value)), - dxl.DXL_HIBYTE(dxl.DXL_HIWORD(value)), - ] - else: - raise NotImplementedError( - f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but " - f"{bytes} is provided instead." - ) - return data - - -def get_group_sync_key(data_name, motor_names): - group_key = f"{data_name}_" + "_".join(motor_names) - return group_key - - -def get_result_name(fn_name, data_name, motor_names): - group_key = get_group_sync_key(data_name, motor_names) - rslt_name = f"{fn_name}_{group_key}" - return rslt_name - - -def get_queue_name(fn_name, data_name, motor_names): - group_key = get_group_sync_key(data_name, motor_names) - queue_name = f"{fn_name}_{group_key}" - return queue_name - - -def get_log_name(var_name, fn_name, data_name, motor_names): - group_key = get_group_sync_key(data_name, motor_names) - log_name = f"{var_name}_{fn_name}_{group_key}" - return log_name - - -def assert_same_address(model_ctrl_table, motor_models, data_name): - all_addr = [] - all_bytes = [] - for model in motor_models: - addr, bytes = model_ctrl_table[model][data_name] - all_addr.append(addr) - all_bytes.append(bytes) - - if len(set(all_addr)) != 1: - raise NotImplementedError( - f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." - ) - - if len(set(all_bytes)) != 1: - raise NotImplementedError( - f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." - ) - - -class TorqueMode(enum.Enum): - ENABLED = 1 - DISABLED = 0 - - -class DriveMode(enum.Enum): - NON_INVERTED = 0 - INVERTED = 1 - - -class CalibrationMode(enum.Enum): - # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] - DEGREE = 0 - # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100] - LINEAR = 1 - - -class JointOutOfRangeError(Exception): - def __init__(self, message="Joint is out of range"): - self.message = message - super().__init__(self.message) - - -class DynamixelMotorsBus: - """ - The DynamixelMotorsBus class allows to efficiently read and write to the attached motors. It relies on - the python dynamixel sdk to communicate with the motors. For more info, see the [Dynamixel SDK Documentation](https://emanual.robotis.com/docs/en/software/dynamixel/dynamixel_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20). - - A DynamixelMotorsBus instance requires a port (e.g. `DynamixelMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). - To find the port, you can run our utility script: - ```bash - python lerobot/scripts/find_motors_bus_port.py - >>> Finding all available ports for the MotorBus. - >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] - >>> Remove the usb cable from your DynamixelMotorsBus and press Enter when done. - >>> The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751. - >>> Reconnect the usb cable. - ``` - - Example of usage for 1 motor connected to the bus: - ```python - motor_name = "gripper" - motor_index = 6 - motor_model = "xl330-m288" - - config = DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem575E0031751", - motors={motor_name: (motor_index, motor_model)}, - ) - motors_bus = DynamixelMotorsBus(config) - motors_bus.connect() - - position = motors_bus.read("Present_Position") - - # move from a few motor steps as an example - few_steps = 30 - motors_bus.write("Goal_Position", position + few_steps) - - # when done, consider disconnecting - motors_bus.disconnect() - ``` - """ - - def __init__( - self, - config: DynamixelMotorsBusConfig, - ): - self.port = config.port - self.motors = config.motors - self.mock = config.mock - - self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) - self.model_resolution = deepcopy(MODEL_RESOLUTION) - - self.port_handler = None - self.packet_handler = None - self.calibration = None - self.is_connected = False - self.group_readers = {} - self.group_writers = {} - self.logs = {} - - def connect(self): - if self.is_connected: - raise RobotDeviceAlreadyConnectedError( - f"DynamixelMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice." - ) - - if self.mock: - import tests.motors.mock_dynamixel_sdk as dxl - else: - import dynamixel_sdk as dxl - - self.port_handler = dxl.PortHandler(self.port) - self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) - - try: - if not self.port_handler.openPort(): - raise OSError(f"Failed to open port '{self.port}'.") - except Exception: - traceback.print_exc() - print( - "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n" - ) - raise - - # Allow to read and write - self.is_connected = True - - self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS) - - def reconnect(self): - if self.mock: - import tests.motors.mock_dynamixel_sdk as dxl - else: - import dynamixel_sdk as dxl - - self.port_handler = dxl.PortHandler(self.port) - self.packet_handler = dxl.PacketHandler(PROTOCOL_VERSION) - - if not self.port_handler.openPort(): - raise OSError(f"Failed to open port '{self.port}'.") - - self.is_connected = True - - def are_motors_configured(self): - # Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, - # a ConnectionError will be raised anyway. - try: - return (self.motor_indices == self.read("ID")).all() - except ConnectionError as e: - print(e) - return False - - def find_motor_indices(self, possible_ids=None, num_retry=2): - if possible_ids is None: - possible_ids = range(MAX_ID_RANGE) - - indices = [] - for idx in tqdm.tqdm(possible_ids): - try: - present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] - except ConnectionError: - continue - - if idx != present_idx: - # sanity check - raise OSError( - "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged." - ) - indices.append(idx) - - return indices - - def set_bus_baudrate(self, baudrate): - present_bus_baudrate = self.port_handler.getBaudRate() - if present_bus_baudrate != baudrate: - print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") - self.port_handler.setBaudRate(baudrate) - - if self.port_handler.getBaudRate() != baudrate: - raise OSError("Failed to write bus baud rate.") - - @property - def motor_names(self) -> list[str]: - return list(self.motors.keys()) - - @property - def motor_models(self) -> list[str]: - return [model for _, model in self.motors.values()] - - @property - def motor_indices(self) -> list[int]: - return [idx for idx, _ in self.motors.values()] - - def set_calibration(self, calibration: dict[str, list]): - self.calibration = calibration - - def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): - """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct. - - For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. - """ - try: - values = self.apply_calibration(values, motor_names) - except JointOutOfRangeError as e: - print(e) - self.autocorrect_calibration(values, motor_names) - values = self.apply_calibration(values, motor_names) - return values - - def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): - """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with - a "zero position" at 0 degree. - - Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor - rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range. - - Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation - when given a goal position that is + or - their resolution. For instance, dynamixel xl330-m077 have a resolution of 4096, and - at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830, - or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor. - To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work - in the centered nominal degree range ]-180, 180[. - """ - if motor_names is None: - motor_names = self.motor_names - - # Convert from unsigned int32 original range [0, 2**32] to signed float32 range - values = values.astype(np.float32) - - for i, name in enumerate(motor_names): - calib_idx = self.calibration["motor_names"].index(name) - calib_mode = self.calibration["calib_mode"][calib_idx] - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - drive_mode = self.calibration["drive_mode"][calib_idx] - homing_offset = self.calibration["homing_offset"][calib_idx] - _, model = self.motors[name] - resolution = self.model_resolution[model] - - # Update direction of rotation of the motor to match between leader and follower. - # In fact, the motor of the leader for a given joint can be assembled in an - # opposite direction in term of rotation than the motor of the follower on the same joint. - if drive_mode: - values[i] *= -1 - - # Convert from range [-2**31, 2**31] to - # nominal range [-resolution//2, resolution//2] (e.g. [-2048, 2048]) - values[i] += homing_offset - - # Convert from range [-resolution//2, resolution//2] to - # universal float32 centered degree range [-180, 180] - # (e.g. 2048 / (4096 // 2) * 180 = 180) - values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE - - if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE): - raise JointOutOfRangeError( - f"Wrong motor position range detected for {name}. " - f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), " - f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, " - f"but present value is {values[i]} degree. " - "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. " - "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" - ) - - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - start_pos = self.calibration["start_pos"][calib_idx] - end_pos = self.calibration["end_pos"][calib_idx] - - # Rescale the present position to a nominal range [0, 100] %, - # useful for joints with linear motions like Aloha gripper - values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100 - - if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR): - raise JointOutOfRangeError( - f"Wrong motor position range detected for {name}. " - f"Expected to be in nominal range of [0, 100] % (a full linear translation), " - f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, " - f"but present value is {values[i]} %. " - "This might be due to a cable connection issue creating an artificial jump in motor values. " - "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" - ) - - return values - - def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): - """This function automatically detects issues with values of motors after calibration, and correct for these issues. - - Some motors might have values outside of expected maximum bounds after calibration. - For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given - a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position. - - Known issues: - #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors. - #2: Motor internal homing offset is shifted by a full turn, caused by using default calibration (e.g Aloha). - #3: motor internal homing offset is shifted by less or more than a full turn, caused by using default calibration - or by human error during manual calibration. - - Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn. - Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`, - that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue. - - Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. - """ - if motor_names is None: - motor_names = self.motor_names - - # Convert from unsigned int32 original range [0, 2**32] to signed float32 range - values = values.astype(np.float32) - - for i, name in enumerate(motor_names): - calib_idx = self.calibration["motor_names"].index(name) - calib_mode = self.calibration["calib_mode"][calib_idx] - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - drive_mode = self.calibration["drive_mode"][calib_idx] - homing_offset = self.calibration["homing_offset"][calib_idx] - _, model = self.motors[name] - resolution = self.model_resolution[model] - - # Update direction of rotation of the motor to match between leader and follower. - # In fact, the motor of the leader for a given joint can be assembled in an - # opposite direction in term of rotation than the motor of the follower on the same joint. - if drive_mode: - values[i] *= -1 - - # Convert from initial range to range [-180, 180] degrees - calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) - - # Solve this inequality to find the factor to shift the range into [-180, 180] degrees - # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE - # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE - # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution - low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution - upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution - - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - start_pos = self.calibration["start_pos"][calib_idx] - end_pos = self.calibration["end_pos"][calib_idx] - - # Convert from initial range to range [0, 100] in % - calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) - - # Solve this inequality to find the factor to shift the range into [0, 100] % - # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 - # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 - # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100 - # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution - low_factor = (start_pos - values[i]) / resolution - upp_factor = (end_pos - values[i]) / resolution - - if not in_range: - # Get first integer between the two bounds - if low_factor < upp_factor: - factor = math.ceil(low_factor) - - if factor > upp_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") - else: - factor = math.ceil(upp_factor) - - if factor > low_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" - in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - - logging.warning( - f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " - f"from '{out_of_range_str}' to '{in_range_str}'." - ) - - # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. - self.calibration["homing_offset"][calib_idx] += resolution * factor - - def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): - """Inverse of `apply_calibration`.""" - if motor_names is None: - motor_names = self.motor_names - - for i, name in enumerate(motor_names): - calib_idx = self.calibration["motor_names"].index(name) - calib_mode = self.calibration["calib_mode"][calib_idx] - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - drive_mode = self.calibration["drive_mode"][calib_idx] - homing_offset = self.calibration["homing_offset"][calib_idx] - _, model = self.motors[name] - resolution = self.model_resolution[model] - - # Convert from nominal 0-centered degree range [-180, 180] to - # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096) - values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2) - - # Subtract the homing offsets to come back to actual motor range of values - # which can be arbitrary. - values[i] -= homing_offset - - # Remove drive mode, which is the rotation direction of the motor, to come back to - # actual motor rotation direction which can be arbitrary. - if drive_mode: - values[i] *= -1 - - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - start_pos = self.calibration["start_pos"][calib_idx] - end_pos = self.calibration["end_pos"][calib_idx] - - # Convert from nominal lnear range of [0, 100] % to - # actual motor range of values which can be arbitrary. - values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos - - values = np.round(values).astype(np.int32) - return values - - def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): - if self.mock: - import tests.motors.mock_dynamixel_sdk as dxl - else: - import dynamixel_sdk as dxl - - return_list = True - if not isinstance(motor_ids, list): - return_list = False - motor_ids = [motor_ids] - - assert_same_address(self.model_ctrl_table, self.motor_models, data_name) - addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] - group = dxl.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) - for idx in motor_ids: - group.addParam(idx) - - for _ in range(num_retry): - comm = group.txRxPacket() - if comm == dxl.COMM_SUCCESS: - break - - if comm != dxl.COMM_SUCCESS: - raise ConnectionError( - f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - values = [] - for idx in motor_ids: - value = group.getData(idx, addr, bytes) - values.append(value) - - if return_list: - return values - else: - return values[0] - - def read(self, data_name, motor_names: str | list[str] | None = None): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." - ) - - start_time = time.perf_counter() - - if self.mock: - import tests.motors.mock_dynamixel_sdk as dxl - else: - import dynamixel_sdk as dxl - - if motor_names is None: - motor_names = self.motor_names - - if isinstance(motor_names, str): - motor_names = [motor_names] - - motor_ids = [] - models = [] - for name in motor_names: - motor_idx, model = self.motors[name] - motor_ids.append(motor_idx) - models.append(model) - - assert_same_address(self.model_ctrl_table, models, data_name) - addr, bytes = self.model_ctrl_table[model][data_name] - group_key = get_group_sync_key(data_name, motor_names) - - if data_name not in self.group_readers: - # create new group reader - self.group_readers[group_key] = dxl.GroupSyncRead( - self.port_handler, self.packet_handler, addr, bytes - ) - for idx in motor_ids: - self.group_readers[group_key].addParam(idx) - - for _ in range(NUM_READ_RETRY): - comm = self.group_readers[group_key].txRxPacket() - if comm == dxl.COMM_SUCCESS: - break - - if comm != dxl.COMM_SUCCESS: - raise ConnectionError( - f"Read failed due to communication error on port {self.port} for group_key {group_key}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - values = [] - for idx in motor_ids: - value = self.group_readers[group_key].getData(idx, addr, bytes) - values.append(value) - - values = np.array(values) - - # Convert to signed int to use range [-2048, 2048] for our motor positions. - if data_name in CONVERT_UINT32_TO_INT32_REQUIRED: - values = values.astype(np.int32) - - if data_name in CALIBRATION_REQUIRED and self.calibration is not None: - values = self.apply_calibration_autocorrect(values, motor_names) - - # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) - self.logs[delta_ts_name] = time.perf_counter() - start_time - - # log the utc time at which the data was received - ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names) - self.logs[ts_utc_name] = capture_timestamp_utc() - - return values - - def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): - if self.mock: - import tests.motors.mock_dynamixel_sdk as dxl - else: - import dynamixel_sdk as dxl - - if not isinstance(motor_ids, list): - motor_ids = [motor_ids] - if not isinstance(values, list): - values = [values] - - assert_same_address(self.model_ctrl_table, motor_models, data_name) - addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] - group = dxl.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) - for idx, value in zip(motor_ids, values, strict=True): - data = convert_to_bytes(value, bytes, self.mock) - group.addParam(idx, data) - - for _ in range(num_retry): - comm = group.txPacket() - if comm == dxl.COMM_SUCCESS: - break - - if comm != dxl.COMM_SUCCESS: - raise ConnectionError( - f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." - ) - - start_time = time.perf_counter() - - if self.mock: - import tests.motors.mock_dynamixel_sdk as dxl - else: - import dynamixel_sdk as dxl - - if motor_names is None: - motor_names = self.motor_names - - if isinstance(motor_names, str): - motor_names = [motor_names] - - if isinstance(values, (int, float, np.integer)): - values = [int(values)] * len(motor_names) - - values = np.array(values) - - motor_ids = [] - models = [] - for name in motor_names: - motor_idx, model = self.motors[name] - motor_ids.append(motor_idx) - models.append(model) - - if data_name in CALIBRATION_REQUIRED and self.calibration is not None: - values = self.revert_calibration(values, motor_names) - - values = values.tolist() - - assert_same_address(self.model_ctrl_table, models, data_name) - addr, bytes = self.model_ctrl_table[model][data_name] - group_key = get_group_sync_key(data_name, motor_names) - - init_group = data_name not in self.group_readers - if init_group: - self.group_writers[group_key] = dxl.GroupSyncWrite( - self.port_handler, self.packet_handler, addr, bytes - ) - - for idx, value in zip(motor_ids, values, strict=True): - data = convert_to_bytes(value, bytes, self.mock) - if init_group: - self.group_writers[group_key].addParam(idx, data) - else: - self.group_writers[group_key].changeParam(idx, data) - - comm = self.group_writers[group_key].txPacket() - if comm != dxl.COMM_SUCCESS: - raise ConnectionError( - f"Write failed due to communication error on port {self.port} for group_key {group_key}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) - self.logs[delta_ts_name] = time.perf_counter() - start_time - - # TODO(rcadene): should we log the time before sending the write command? - # log the utc time when the write has been completed - ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) - self.logs[ts_utc_name] = capture_timestamp_utc() - - def disconnect(self): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"DynamixelMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first." - ) - - if self.port_handler is not None: - self.port_handler.closePort() - self.port_handler = None - - self.packet_handler = None - self.group_readers = {} - self.group_writers = {} - self.is_connected = False - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py deleted file mode 100644 index 64c7f413d..000000000 --- a/lerobot/common/robot_devices/motors/feetech.py +++ /dev/null @@ -1,898 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import enum -import logging -import math -import time -import traceback -from copy import deepcopy - -import numpy as np -import tqdm - -from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError -from lerobot.common.utils.utils import capture_timestamp_utc - -PROTOCOL_VERSION = 0 -BAUDRATE = 1_000_000 -TIMEOUT_MS = 1000 - -MAX_ID_RANGE = 252 - -# The following bounds define the lower and upper joints range (after calibration). -# For joints in degree (i.e. revolute joints), their nominal range is [-180, 180] degrees -# which corresponds to a half rotation on the left and half rotation on the right. -# Some joints might require higher range, so we allow up to [-270, 270] degrees until -# an error is raised. -LOWER_BOUND_DEGREE = -270 -UPPER_BOUND_DEGREE = 270 -# For joints in percentage (i.e. joints that move linearly like the prismatic joint of a gripper), -# their nominal range is [0, 100] %. For instance, for Aloha gripper, 0% is fully -# closed, and 100% is fully open. To account for slight calibration issue, we allow up to -# [-10, 110] until an error is raised. -LOWER_BOUND_LINEAR = -10 -UPPER_BOUND_LINEAR = 110 - -HALF_TURN_DEGREE = 180 - - -# See this link for STS3215 Memory Table: -# https://docs.google.com/spreadsheets/d/1GVs7W1VS1PqdhA1nW-abeyAHhTUxKUdR/edit?usp=sharing&ouid=116566590112741600240&rtpof=true&sd=true -# data_name: (address, size_byte) -SCS_SERIES_CONTROL_TABLE = { - "Model": (3, 2), - "ID": (5, 1), - "Baud_Rate": (6, 1), - "Return_Delay": (7, 1), - "Response_Status_Level": (8, 1), - "Min_Angle_Limit": (9, 2), - "Max_Angle_Limit": (11, 2), - "Max_Temperature_Limit": (13, 1), - "Max_Voltage_Limit": (14, 1), - "Min_Voltage_Limit": (15, 1), - "Max_Torque_Limit": (16, 2), - "Phase": (18, 1), - "Unloading_Condition": (19, 1), - "LED_Alarm_Condition": (20, 1), - "P_Coefficient": (21, 1), - "D_Coefficient": (22, 1), - "I_Coefficient": (23, 1), - "Minimum_Startup_Force": (24, 2), - "CW_Dead_Zone": (26, 1), - "CCW_Dead_Zone": (27, 1), - "Protection_Current": (28, 2), - "Angular_Resolution": (30, 1), - "Offset": (31, 2), - "Mode": (33, 1), - "Protective_Torque": (34, 1), - "Protection_Time": (35, 1), - "Overload_Torque": (36, 1), - "Speed_closed_loop_P_proportional_coefficient": (37, 1), - "Over_Current_Protection_Time": (38, 1), - "Velocity_closed_loop_I_integral_coefficient": (39, 1), - "Torque_Enable": (40, 1), - "Acceleration": (41, 1), - "Goal_Position": (42, 2), - "Goal_Time": (44, 2), - "Goal_Speed": (46, 2), - "Torque_Limit": (48, 2), - "Lock": (55, 1), - "Present_Position": (56, 2), - "Present_Speed": (58, 2), - "Present_Load": (60, 2), - "Present_Voltage": (62, 1), - "Present_Temperature": (63, 1), - "Status": (65, 1), - "Moving": (66, 1), - "Present_Current": (69, 2), - # Not in the Memory Table - "Maximum_Acceleration": (85, 2), -} - -SCS_SERIES_BAUDRATE_TABLE = { - 0: 1_000_000, - 1: 500_000, - 2: 250_000, - 3: 128_000, - 4: 115_200, - 5: 57_600, - 6: 38_400, - 7: 19_200, -} - -CALIBRATION_REQUIRED = ["Goal_Position", "Present_Position"] -CONVERT_UINT32_TO_INT32_REQUIRED = ["Goal_Position", "Present_Position"] - - -MODEL_CONTROL_TABLE = { - "scs_series": SCS_SERIES_CONTROL_TABLE, - "sts3215": SCS_SERIES_CONTROL_TABLE, -} - -MODEL_RESOLUTION = { - "scs_series": 4096, - "sts3215": 4096, -} - -MODEL_BAUDRATE_TABLE = { - "scs_series": SCS_SERIES_BAUDRATE_TABLE, - "sts3215": SCS_SERIES_BAUDRATE_TABLE, -} - -# High number of retries is needed for feetech compared to dynamixel motors. -NUM_READ_RETRY = 20 -NUM_WRITE_RETRY = 20 - - -def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: - """This function converts the degree range to the step range for indicating motors rotation. - It assumes a motor achieves a full rotation by going from -180 degree position to +180. - The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. - """ - resolutions = [MODEL_RESOLUTION[model] for model in models] - steps = degrees / 180 * np.array(resolutions) / 2 - steps = steps.astype(int) - return steps - - -def convert_to_bytes(value, bytes, mock=False): - if mock: - return value - - import scservo_sdk as scs - - # Note: No need to convert back into unsigned int, since this byte preprocessing - # already handles it for us. - if bytes == 1: - data = [ - scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), - ] - elif bytes == 2: - data = [ - scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), - scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), - ] - elif bytes == 4: - data = [ - scs.SCS_LOBYTE(scs.SCS_LOWORD(value)), - scs.SCS_HIBYTE(scs.SCS_LOWORD(value)), - scs.SCS_LOBYTE(scs.SCS_HIWORD(value)), - scs.SCS_HIBYTE(scs.SCS_HIWORD(value)), - ] - else: - raise NotImplementedError( - f"Value of the number of bytes to be sent is expected to be in [1, 2, 4], but " - f"{bytes} is provided instead." - ) - return data - - -def get_group_sync_key(data_name, motor_names): - group_key = f"{data_name}_" + "_".join(motor_names) - return group_key - - -def get_result_name(fn_name, data_name, motor_names): - group_key = get_group_sync_key(data_name, motor_names) - rslt_name = f"{fn_name}_{group_key}" - return rslt_name - - -def get_queue_name(fn_name, data_name, motor_names): - group_key = get_group_sync_key(data_name, motor_names) - queue_name = f"{fn_name}_{group_key}" - return queue_name - - -def get_log_name(var_name, fn_name, data_name, motor_names): - group_key = get_group_sync_key(data_name, motor_names) - log_name = f"{var_name}_{fn_name}_{group_key}" - return log_name - - -def assert_same_address(model_ctrl_table, motor_models, data_name): - all_addr = [] - all_bytes = [] - for model in motor_models: - addr, bytes = model_ctrl_table[model][data_name] - all_addr.append(addr) - all_bytes.append(bytes) - - if len(set(all_addr)) != 1: - raise NotImplementedError( - f"At least two motor models use a different address for `data_name`='{data_name}' ({list(zip(motor_models, all_addr, strict=False))}). Contact a LeRobot maintainer." - ) - - if len(set(all_bytes)) != 1: - raise NotImplementedError( - f"At least two motor models use a different bytes representation for `data_name`='{data_name}' ({list(zip(motor_models, all_bytes, strict=False))}). Contact a LeRobot maintainer." - ) - - -class TorqueMode(enum.Enum): - ENABLED = 1 - DISABLED = 0 - - -class DriveMode(enum.Enum): - NON_INVERTED = 0 - INVERTED = 1 - - -class CalibrationMode(enum.Enum): - # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] - DEGREE = 0 - # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100] - LINEAR = 1 - - -class JointOutOfRangeError(Exception): - def __init__(self, message="Joint is out of range"): - self.message = message - super().__init__(self.message) - - -class FeetechMotorsBus: - """ - The FeetechMotorsBus class allows to efficiently read and write to the attached motors. It relies on - the python feetech sdk to communicate with the motors. For more info, see the [feetech SDK Documentation](https://emanual.robotis.com/docs/en/software/feetech/feetech_sdk/sample_code/python_read_write_protocol_2_0/#python-read-write-protocol-20). - - A FeetechMotorsBus instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). - To find the port, you can run our utility script: - ```bash - python lerobot/scripts/find_motors_bus_port.py - >>> Finding all available ports for the MotorsBus. - >>> ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] - >>> Remove the usb cable from your FeetechMotorsBus and press Enter when done. - >>> The port of this FeetechMotorsBus is /dev/tty.usbmodem575E0031751. - >>> Reconnect the usb cable. - ``` - - Example of usage for 1 motor connected to the bus: - ```python - motor_name = "gripper" - motor_index = 6 - motor_model = "sts3215" - - config = FeetechMotorsBusConfig( - port="/dev/tty.usbmodem575E0031751", - motors={motor_name: (motor_index, motor_model)}, - ) - motors_bus = FeetechMotorsBus(config) - motors_bus.connect() - - position = motors_bus.read("Present_Position") - - # move from a few motor steps as an example - few_steps = 30 - motors_bus.write("Goal_Position", position + few_steps) - - # when done, consider disconnecting - motors_bus.disconnect() - ``` - """ - - def __init__( - self, - config: FeetechMotorsBusConfig, - ): - self.port = config.port - self.motors = config.motors - self.mock = config.mock - - self.model_ctrl_table = deepcopy(MODEL_CONTROL_TABLE) - self.model_resolution = deepcopy(MODEL_RESOLUTION) - - self.port_handler = None - self.packet_handler = None - self.calibration = None - self.is_connected = False - self.group_readers = {} - self.group_writers = {} - self.logs = {} - - self.track_positions = {} - - def connect(self): - if self.is_connected: - raise RobotDeviceAlreadyConnectedError( - f"FeetechMotorsBus({self.port}) is already connected. Do not call `motors_bus.connect()` twice." - ) - - if self.mock: - import tests.motors.mock_scservo_sdk as scs - else: - import scservo_sdk as scs - - self.port_handler = scs.PortHandler(self.port) - self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) - - try: - if not self.port_handler.openPort(): - raise OSError(f"Failed to open port '{self.port}'.") - except Exception: - traceback.print_exc() - print( - "\nTry running `python lerobot/scripts/find_motors_bus_port.py` to make sure you are using the correct port.\n" - ) - raise - - # Allow to read and write - self.is_connected = True - - self.port_handler.setPacketTimeoutMillis(TIMEOUT_MS) - - def reconnect(self): - if self.mock: - import tests.motors.mock_scservo_sdk as scs - else: - import scservo_sdk as scs - - self.port_handler = scs.PortHandler(self.port) - self.packet_handler = scs.PacketHandler(PROTOCOL_VERSION) - - if not self.port_handler.openPort(): - raise OSError(f"Failed to open port '{self.port}'.") - - self.is_connected = True - - def are_motors_configured(self): - # Only check the motor indices and not baudrate, since if the motor baudrates are incorrect, - # a ConnectionError will be raised anyway. - try: - return (self.motor_indices == self.read("ID")).all() - except ConnectionError as e: - print(e) - return False - - def find_motor_indices(self, possible_ids=None, num_retry=2): - if possible_ids is None: - possible_ids = range(MAX_ID_RANGE) - - indices = [] - for idx in tqdm.tqdm(possible_ids): - try: - present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] - except ConnectionError: - continue - - if idx != present_idx: - # sanity check - raise OSError( - "Motor index used to communicate through the bus is not the same as the one present in the motor memory. The motor memory might be damaged." - ) - indices.append(idx) - - return indices - - def set_bus_baudrate(self, baudrate): - present_bus_baudrate = self.port_handler.getBaudRate() - if present_bus_baudrate != baudrate: - print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") - self.port_handler.setBaudRate(baudrate) - - if self.port_handler.getBaudRate() != baudrate: - raise OSError("Failed to write bus baud rate.") - - @property - def motor_names(self) -> list[str]: - return list(self.motors.keys()) - - @property - def motor_models(self) -> list[str]: - return [model for _, model in self.motors.values()] - - @property - def motor_indices(self) -> list[int]: - return [idx for idx, _ in self.motors.values()] - - def set_calibration(self, calibration: dict[str, list]): - self.calibration = calibration - - def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): - """This function apply the calibration, automatically detects out of range errors for motors values and attempt to correct. - - For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. - """ - try: - values = self.apply_calibration(values, motor_names) - except JointOutOfRangeError as e: - print(e) - self.autocorrect_calibration(values, motor_names) - values = self.apply_calibration(values, motor_names) - return values - - def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): - """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with - a "zero position" at 0 degree. - - Note: We say "nominal degree range" since the motors can take values outside this range. For instance, 190 degrees, if the motor - rotate more than a half a turn from the zero position. However, most motors can't rotate more than 180 degrees and will stay in this range. - - Joints values are original in [0, 2**32[ (unsigned int32). Each motor are expected to complete a full rotation - when given a goal position that is + or - their resolution. For instance, feetech xl330-m077 have a resolution of 4096, and - at any position in their original range, let's say the position 56734, they complete a full rotation clockwise by moving to 60830, - or anticlockwise by moving to 52638. The position in the original range is arbitrary and might change a lot between each motor. - To harmonize between motors of the same model, different robots, or even models of different brands, we propose to work - in the centered nominal degree range ]-180, 180[. - """ - if motor_names is None: - motor_names = self.motor_names - - # Convert from unsigned int32 original range [0, 2**32] to signed float32 range - values = values.astype(np.float32) - - for i, name in enumerate(motor_names): - calib_idx = self.calibration["motor_names"].index(name) - calib_mode = self.calibration["calib_mode"][calib_idx] - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - drive_mode = self.calibration["drive_mode"][calib_idx] - homing_offset = self.calibration["homing_offset"][calib_idx] - _, model = self.motors[name] - resolution = self.model_resolution[model] - - # Update direction of rotation of the motor to match between leader and follower. - # In fact, the motor of the leader for a given joint can be assembled in an - # opposite direction in term of rotation than the motor of the follower on the same joint. - if drive_mode: - values[i] *= -1 - - # Convert from range [-2**31, 2**31[ to - # nominal range ]-resolution, resolution[ (e.g. ]-2048, 2048[) - values[i] += homing_offset - - # Convert from range ]-resolution, resolution[ to - # universal float32 centered degree range ]-180, 180[ - values[i] = values[i] / (resolution // 2) * HALF_TURN_DEGREE - - if (values[i] < LOWER_BOUND_DEGREE) or (values[i] > UPPER_BOUND_DEGREE): - raise JointOutOfRangeError( - f"Wrong motor position range detected for {name}. " - f"Expected to be in nominal range of [-{HALF_TURN_DEGREE}, {HALF_TURN_DEGREE}] degrees (a full rotation), " - f"with a maximum range of [{LOWER_BOUND_DEGREE}, {UPPER_BOUND_DEGREE}] degrees to account for joints that can rotate a bit more, " - f"but present value is {values[i]} degree. " - "This might be due to a cable connection issue creating an artificial 360 degrees jump in motor values. " - "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" - ) - - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - start_pos = self.calibration["start_pos"][calib_idx] - end_pos = self.calibration["end_pos"][calib_idx] - - # Rescale the present position to a nominal range [0, 100] %, - # useful for joints with linear motions like Aloha gripper - values[i] = (values[i] - start_pos) / (end_pos - start_pos) * 100 - - if (values[i] < LOWER_BOUND_LINEAR) or (values[i] > UPPER_BOUND_LINEAR): - raise JointOutOfRangeError( - f"Wrong motor position range detected for {name}. " - f"Expected to be in nominal range of [0, 100] % (a full linear translation), " - f"with a maximum range of [{LOWER_BOUND_LINEAR}, {UPPER_BOUND_LINEAR}] % to account for some imprecision during calibration, " - f"but present value is {values[i]} %. " - "This might be due to a cable connection issue creating an artificial jump in motor values. " - "You need to recalibrate by running: `python lerobot/scripts/control_robot.py calibrate`" - ) - - return values - - def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): - """This function automatically detects issues with values of motors after calibration, and correct for these issues. - - Some motors might have values outside of expected maximum bounds after calibration. - For instance, for a joint in degree, its value can be outside [-270, 270] degrees, which is totally unexpected given - a nominal range of [-180, 180] degrees, which represents half a turn to the left or right starting from zero position. - - Known issues: - #1: Motor value randomly shifts of a full turn, caused by hardware/connection errors. - #2: Motor internal homing offset is shifted of a full turn, caused by using default calibration (e.g Aloha). - #3: motor internal homing offset is shifted of less or more than a full turn, caused by using default calibration - or by human error during manual calibration. - - Issues #1 and #2 can be solved by shifting the calibration homing offset by a full turn. - Issue #3 will be visually detected by user and potentially captured by the safety feature `max_relative_target`, - that will slow down the motor, raise an error asking to recalibrate. Manual recalibrating will solve the issue. - - Note: A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. - """ - if motor_names is None: - motor_names = self.motor_names - - # Convert from unsigned int32 original range [0, 2**32] to signed float32 range - values = values.astype(np.float32) - - for i, name in enumerate(motor_names): - calib_idx = self.calibration["motor_names"].index(name) - calib_mode = self.calibration["calib_mode"][calib_idx] - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - drive_mode = self.calibration["drive_mode"][calib_idx] - homing_offset = self.calibration["homing_offset"][calib_idx] - _, model = self.motors[name] - resolution = self.model_resolution[model] - - if drive_mode: - values[i] *= -1 - - # Convert from initial range to range [-180, 180] degrees - calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) - - # Solve this inequality to find the factor to shift the range into [-180, 180] degrees - # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE - # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE - # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution - low_factor = ( - -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset - ) / resolution - upp_factor = ( - HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset - ) / resolution - - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - start_pos = self.calibration["start_pos"][calib_idx] - end_pos = self.calibration["end_pos"][calib_idx] - - # Convert from initial range to range [0, 100] in % - calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) - - # Solve this inequality to find the factor to shift the range into [0, 100] % - # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 - # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 - # 0 <= (values[i] - start_pos + resolution * factor) / (end_pos - start_pos) * 100 <= 100 - # (start_pos - values[i]) / resolution <= factor <= (end_pos - values[i]) / resolution - low_factor = (start_pos - values[i]) / resolution - upp_factor = (end_pos - values[i]) / resolution - - if not in_range: - # Get first integer between the two bounds - if low_factor < upp_factor: - factor = math.ceil(low_factor) - - if factor > upp_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") - else: - factor = math.ceil(upp_factor) - - if factor > low_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" - in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - - logging.warning( - f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " - f"from '{out_of_range_str}' to '{in_range_str}'." - ) - - # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. - self.calibration["homing_offset"][calib_idx] += resolution * factor - - def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): - """Inverse of `apply_calibration`.""" - if motor_names is None: - motor_names = self.motor_names - - for i, name in enumerate(motor_names): - calib_idx = self.calibration["motor_names"].index(name) - calib_mode = self.calibration["calib_mode"][calib_idx] - - if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: - drive_mode = self.calibration["drive_mode"][calib_idx] - homing_offset = self.calibration["homing_offset"][calib_idx] - _, model = self.motors[name] - resolution = self.model_resolution[model] - - # Convert from nominal 0-centered degree range [-180, 180] to - # 0-centered resolution range (e.g. [-2048, 2048] for resolution=4096) - values[i] = values[i] / HALF_TURN_DEGREE * (resolution // 2) - - # Subtract the homing offsets to come back to actual motor range of values - # which can be arbitrary. - values[i] -= homing_offset - - # Remove drive mode, which is the rotation direction of the motor, to come back to - # actual motor rotation direction which can be arbitrary. - if drive_mode: - values[i] *= -1 - - elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - start_pos = self.calibration["start_pos"][calib_idx] - end_pos = self.calibration["end_pos"][calib_idx] - - # Convert from nominal lnear range of [0, 100] % to - # actual motor range of values which can be arbitrary. - values[i] = values[i] / 100 * (end_pos - start_pos) + start_pos - - values = np.round(values).astype(np.int32) - return values - - def avoid_rotation_reset(self, values, motor_names, data_name): - if data_name not in self.track_positions: - self.track_positions[data_name] = { - "prev": [None] * len(self.motor_names), - # Assume False at initialization - "below_zero": [False] * len(self.motor_names), - "above_max": [False] * len(self.motor_names), - } - - track = self.track_positions[data_name] - - if motor_names is None: - motor_names = self.motor_names - - for i, name in enumerate(motor_names): - idx = self.motor_names.index(name) - - if track["prev"][idx] is None: - track["prev"][idx] = values[i] - continue - - # Detect a full rotation occurred - if abs(track["prev"][idx] - values[i]) > 2048: - # Position went below 0 and got reset to 4095 - if track["prev"][idx] < values[i]: - # So we set negative value by adding a full rotation - values[i] -= 4096 - - # Position went above 4095 and got reset to 0 - elif track["prev"][idx] > values[i]: - # So we add a full rotation - values[i] += 4096 - - track["prev"][idx] = values[i] - - return values - - def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): - if self.mock: - import tests.motors.mock_scservo_sdk as scs - else: - import scservo_sdk as scs - - return_list = True - if not isinstance(motor_ids, list): - return_list = False - motor_ids = [motor_ids] - - assert_same_address(self.model_ctrl_table, self.motor_models, data_name) - addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] - group = scs.GroupSyncRead(self.port_handler, self.packet_handler, addr, bytes) - for idx in motor_ids: - group.addParam(idx) - - for _ in range(num_retry): - comm = group.txRxPacket() - if comm == scs.COMM_SUCCESS: - break - - if comm != scs.COMM_SUCCESS: - raise ConnectionError( - f"Read failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - values = [] - for idx in motor_ids: - value = group.getData(idx, addr, bytes) - values.append(value) - - if return_list: - return values - else: - return values[0] - - def read(self, data_name, motor_names: str | list[str] | None = None): - if self.mock: - import tests.motors.mock_scservo_sdk as scs - else: - import scservo_sdk as scs - - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." - ) - - start_time = time.perf_counter() - - if motor_names is None: - motor_names = self.motor_names - - if isinstance(motor_names, str): - motor_names = [motor_names] - - motor_ids = [] - models = [] - for name in motor_names: - motor_idx, model = self.motors[name] - motor_ids.append(motor_idx) - models.append(model) - - assert_same_address(self.model_ctrl_table, models, data_name) - addr, bytes = self.model_ctrl_table[model][data_name] - group_key = get_group_sync_key(data_name, motor_names) - - if data_name not in self.group_readers: - # Very Important to flush the buffer! - self.port_handler.ser.reset_output_buffer() - self.port_handler.ser.reset_input_buffer() - - # create new group reader - self.group_readers[group_key] = scs.GroupSyncRead( - self.port_handler, self.packet_handler, addr, bytes - ) - for idx in motor_ids: - self.group_readers[group_key].addParam(idx) - - for _ in range(NUM_READ_RETRY): - comm = self.group_readers[group_key].txRxPacket() - if comm == scs.COMM_SUCCESS: - break - - if comm != scs.COMM_SUCCESS: - raise ConnectionError( - f"Read failed due to communication error on port {self.port} for group_key {group_key}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - values = [] - for idx in motor_ids: - value = self.group_readers[group_key].getData(idx, addr, bytes) - values.append(value) - - values = np.array(values) - - # Convert to signed int to use range [-2048, 2048] for our motor positions. - if data_name in CONVERT_UINT32_TO_INT32_REQUIRED: - values = values.astype(np.int32) - - if data_name in CALIBRATION_REQUIRED: - values = self.avoid_rotation_reset(values, motor_names, data_name) - - if data_name in CALIBRATION_REQUIRED and self.calibration is not None: - values = self.apply_calibration_autocorrect(values, motor_names) - - # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) - self.logs[delta_ts_name] = time.perf_counter() - start_time - - # log the utc time at which the data was received - ts_utc_name = get_log_name("timestamp_utc", "read", data_name, motor_names) - self.logs[ts_utc_name] = capture_timestamp_utc() - - return values - - def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): - if self.mock: - import tests.motors.mock_scservo_sdk as scs - else: - import scservo_sdk as scs - - if not isinstance(motor_ids, list): - motor_ids = [motor_ids] - if not isinstance(values, list): - values = [values] - - assert_same_address(self.model_ctrl_table, motor_models, data_name) - addr, bytes = self.model_ctrl_table[motor_models[0]][data_name] - group = scs.GroupSyncWrite(self.port_handler, self.packet_handler, addr, bytes) - for idx, value in zip(motor_ids, values, strict=True): - data = convert_to_bytes(value, bytes, self.mock) - group.addParam(idx, data) - - for _ in range(num_retry): - comm = group.txPacket() - if comm == scs.COMM_SUCCESS: - break - - if comm != scs.COMM_SUCCESS: - raise ConnectionError( - f"Write failed due to communication error on port {self.port_handler.port_name} for indices {motor_ids}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." - ) - - start_time = time.perf_counter() - - if self.mock: - import tests.motors.mock_scservo_sdk as scs - else: - import scservo_sdk as scs - - if motor_names is None: - motor_names = self.motor_names - - if isinstance(motor_names, str): - motor_names = [motor_names] - - if isinstance(values, (int, float, np.integer)): - values = [int(values)] * len(motor_names) - - values = np.array(values) - - motor_ids = [] - models = [] - for name in motor_names: - motor_idx, model = self.motors[name] - motor_ids.append(motor_idx) - models.append(model) - - if data_name in CALIBRATION_REQUIRED and self.calibration is not None: - values = self.revert_calibration(values, motor_names) - - values = values.tolist() - - assert_same_address(self.model_ctrl_table, models, data_name) - addr, bytes = self.model_ctrl_table[model][data_name] - group_key = get_group_sync_key(data_name, motor_names) - - init_group = data_name not in self.group_readers - if init_group: - self.group_writers[group_key] = scs.GroupSyncWrite( - self.port_handler, self.packet_handler, addr, bytes - ) - - for idx, value in zip(motor_ids, values, strict=True): - data = convert_to_bytes(value, bytes, self.mock) - if init_group: - self.group_writers[group_key].addParam(idx, data) - else: - self.group_writers[group_key].changeParam(idx, data) - - comm = self.group_writers[group_key].txPacket() - if comm != scs.COMM_SUCCESS: - raise ConnectionError( - f"Write failed due to communication error on port {self.port} for group_key {group_key}: " - f"{self.packet_handler.getTxRxResult(comm)}" - ) - - # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) - self.logs[delta_ts_name] = time.perf_counter() - start_time - - # TODO(rcadene): should we log the time before sending the write command? - # log the utc time when the write has been completed - ts_utc_name = get_log_name("timestamp_utc", "write", data_name, motor_names) - self.logs[ts_utc_name] = capture_timestamp_utc() - - def disconnect(self): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - f"FeetechMotorsBus({self.port}) is not connected. Try running `motors_bus.connect()` first." - ) - - if self.port_handler is not None: - self.port_handler.closePort() - self.port_handler = None - - self.packet_handler = None - self.group_readers = {} - self.group_writers = {} - self.is_connected = False - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() diff --git a/lerobot/common/robot_devices/motors/utils.py b/lerobot/common/robot_devices/motors/utils.py deleted file mode 100644 index bd86f4c64..000000000 --- a/lerobot/common/robot_devices/motors/utils.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Protocol - -from lerobot.common.robot_devices.motors.configs import ( - DynamixelMotorsBusConfig, - FeetechMotorsBusConfig, - MotorsBusConfig, -) - - -class MotorsBus(Protocol): - def motor_names(self): ... - def set_calibration(self): ... - def apply_calibration(self): ... - def revert_calibration(self): ... - def read(self): ... - def write(self): ... - - -def make_motors_buses_from_configs(motors_bus_configs: dict[str, MotorsBusConfig]) -> list[MotorsBus]: - motors_buses = {} - - for key, cfg in motors_bus_configs.items(): - if cfg.type == "dynamixel": - from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus - - motors_buses[key] = DynamixelMotorsBus(cfg) - - elif cfg.type == "feetech": - from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus - - motors_buses[key] = FeetechMotorsBus(cfg) - - else: - raise ValueError(f"The motor type '{cfg.type}' is not valid.") - - return motors_buses - - -def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: - if motor_type == "dynamixel": - from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus - - config = DynamixelMotorsBusConfig(**kwargs) - return DynamixelMotorsBus(config) - - elif motor_type == "feetech": - from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus - - config = FeetechMotorsBusConfig(**kwargs) - return FeetechMotorsBus(config) - - else: - raise ValueError(f"The motor type '{motor_type}' is not valid.") diff --git a/lerobot/common/robot_devices/robots/configs.py b/lerobot/common/robot_devices/robots/configs.py deleted file mode 100644 index e940b442f..000000000 --- a/lerobot/common/robot_devices/robots/configs.py +++ /dev/null @@ -1,613 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import abc -from dataclasses import dataclass, field -from typing import Sequence - -import draccus - -from lerobot.common.robot_devices.cameras.configs import ( - CameraConfig, - IntelRealSenseCameraConfig, - OpenCVCameraConfig, -) -from lerobot.common.robot_devices.motors.configs import ( - DynamixelMotorsBusConfig, - FeetechMotorsBusConfig, - MotorsBusConfig, -) - - -@dataclass -class RobotConfig(draccus.ChoiceRegistry, abc.ABC): - @property - def type(self) -> str: - return self.get_choice_name(self.__class__) - - -# TODO(rcadene, aliberts): remove ManipulatorRobotConfig abstraction -@dataclass -class ManipulatorRobotConfig(RobotConfig): - leader_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) - follower_arms: dict[str, MotorsBusConfig] = field(default_factory=lambda: {}) - cameras: dict[str, CameraConfig] = field(default_factory=lambda: {}) - - # Optionally limit the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length - # as the number of motors in your follower arms (assumes all follower arms have the same number of - # motors). - max_relative_target: list[float] | float | None = None - - # Optionally set the leader arm in torque mode with the gripper motor set to this angle. This makes it - # possible to squeeze the gripper and have it spring back to an open position on its own. If None, the - # gripper is not put in torque mode. - gripper_open_degree: float | None = None - - mock: bool = False - - def __post_init__(self): - if self.mock: - for arm in self.leader_arms.values(): - if not arm.mock: - arm.mock = True - for arm in self.follower_arms.values(): - if not arm.mock: - arm.mock = True - for cam in self.cameras.values(): - if not cam.mock: - cam.mock = True - - if self.max_relative_target is not None and isinstance(self.max_relative_target, Sequence): - for name in self.follower_arms: - if len(self.follower_arms[name].motors) != len(self.max_relative_target): - raise ValueError( - f"len(max_relative_target)={len(self.max_relative_target)} but the follower arm with name {name} has " - f"{len(self.follower_arms[name].motors)} motors. Please make sure that the " - f"`max_relative_target` list has as many parameters as there are motors per arm. " - "Note: This feature does not yet work with robots where different follower arms have " - "different numbers of motors." - ) - - -@RobotConfig.register_subclass("aloha") -@dataclass -class AlohaRobotConfig(ManipulatorRobotConfig): - # Specific to Aloha, LeRobot comes with default calibration files. Assuming the motors have been - # properly assembled, no manual calibration step is expected. If you need to run manual calibration, - # simply update this path to ".cache/calibration/aloha" - calibration_dir: str = ".cache/calibration/aloha_default" - - # /!\ FOR SAFETY, READ THIS /!\ - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - # For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default. - # When you feel more confident with teleoperation or running the policy, you can extend - # this safety limit and even removing it by setting it to `null`. - # Also, everything is expected to work safely out-of-the-box, but we highly advise to - # first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml), - # then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully - max_relative_target: int | None = 5 - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "left": DynamixelMotorsBusConfig( - # window_x - port="/dev/ttyDXL_leader_left", - motors={ - # name: (index, model) - "waist": [1, "xm430-w350"], - "shoulder": [2, "xm430-w350"], - "shoulder_shadow": [3, "xm430-w350"], - "elbow": [4, "xm430-w350"], - "elbow_shadow": [5, "xm430-w350"], - "forearm_roll": [6, "xm430-w350"], - "wrist_angle": [7, "xm430-w350"], - "wrist_rotate": [8, "xl430-w250"], - "gripper": [9, "xc430-w150"], - }, - ), - "right": DynamixelMotorsBusConfig( - # window_x - port="/dev/ttyDXL_leader_right", - motors={ - # name: (index, model) - "waist": [1, "xm430-w350"], - "shoulder": [2, "xm430-w350"], - "shoulder_shadow": [3, "xm430-w350"], - "elbow": [4, "xm430-w350"], - "elbow_shadow": [5, "xm430-w350"], - "forearm_roll": [6, "xm430-w350"], - "wrist_angle": [7, "xm430-w350"], - "wrist_rotate": [8, "xl430-w250"], - "gripper": [9, "xc430-w150"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "left": DynamixelMotorsBusConfig( - port="/dev/ttyDXL_follower_left", - motors={ - # name: (index, model) - "waist": [1, "xm540-w270"], - "shoulder": [2, "xm540-w270"], - "shoulder_shadow": [3, "xm540-w270"], - "elbow": [4, "xm540-w270"], - "elbow_shadow": [5, "xm540-w270"], - "forearm_roll": [6, "xm540-w270"], - "wrist_angle": [7, "xm540-w270"], - "wrist_rotate": [8, "xm430-w350"], - "gripper": [9, "xm430-w350"], - }, - ), - "right": DynamixelMotorsBusConfig( - port="/dev/ttyDXL_follower_right", - motors={ - # name: (index, model) - "waist": [1, "xm540-w270"], - "shoulder": [2, "xm540-w270"], - "shoulder_shadow": [3, "xm540-w270"], - "elbow": [4, "xm540-w270"], - "elbow_shadow": [5, "xm540-w270"], - "forearm_roll": [6, "xm540-w270"], - "wrist_angle": [7, "xm540-w270"], - "wrist_rotate": [8, "xm430-w350"], - "gripper": [9, "xm430-w350"], - }, - ), - } - ) - - # Troubleshooting: If one of your IntelRealSense cameras freeze during - # data recording due to bandwidth limit, you might need to plug the camera - # on another USB hub or PCIe card. - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "cam_high": IntelRealSenseCameraConfig( - serial_number=128422271347, - fps=30, - width=640, - height=480, - ), - "cam_low": IntelRealSenseCameraConfig( - serial_number=130322270656, - fps=30, - width=640, - height=480, - ), - "cam_left_wrist": IntelRealSenseCameraConfig( - serial_number=218622272670, - fps=30, - width=640, - height=480, - ), - "cam_right_wrist": IntelRealSenseCameraConfig( - serial_number=130322272300, - fps=30, - width=640, - height=480, - ), - } - ) - - mock: bool = False - - -@RobotConfig.register_subclass("koch") -@dataclass -class KochRobotConfig(ManipulatorRobotConfig): - calibration_dir: str = ".cache/calibration/koch" - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem585A0085511", - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl330-m077"], - "shoulder_lift": [2, "xl330-m077"], - "elbow_flex": [3, "xl330-m077"], - "wrist_flex": [4, "xl330-m077"], - "wrist_roll": [5, "xl330-m077"], - "gripper": [6, "xl330-m077"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl430-w250"], - "shoulder_lift": [2, "xl430-w250"], - "elbow_flex": [3, "xl330-m288"], - "wrist_flex": [4, "xl330-m288"], - "wrist_roll": [5, "xl330-m288"], - "gripper": [6, "xl330-m288"], - }, - ), - } - ) - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "laptop": OpenCVCameraConfig( - camera_index=0, - fps=30, - width=640, - height=480, - ), - "phone": OpenCVCameraConfig( - camera_index=1, - fps=30, - width=640, - height=480, - ), - } - ) - - # ~ Koch specific settings ~ - # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible - # to squeeze the gripper and have it spring back to an open position on its own. - gripper_open_degree: float = 35.156 - - mock: bool = False - - -@RobotConfig.register_subclass("koch_bimanual") -@dataclass -class KochBimanualRobotConfig(ManipulatorRobotConfig): - calibration_dir: str = ".cache/calibration/koch_bimanual" - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "left": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem585A0085511", - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl330-m077"], - "shoulder_lift": [2, "xl330-m077"], - "elbow_flex": [3, "xl330-m077"], - "wrist_flex": [4, "xl330-m077"], - "wrist_roll": [5, "xl330-m077"], - "gripper": [6, "xl330-m077"], - }, - ), - "right": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem575E0031751", - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl330-m077"], - "shoulder_lift": [2, "xl330-m077"], - "elbow_flex": [3, "xl330-m077"], - "wrist_flex": [4, "xl330-m077"], - "wrist_roll": [5, "xl330-m077"], - "gripper": [6, "xl330-m077"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "left": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl430-w250"], - "shoulder_lift": [2, "xl430-w250"], - "elbow_flex": [3, "xl330-m288"], - "wrist_flex": [4, "xl330-m288"], - "wrist_roll": [5, "xl330-m288"], - "gripper": [6, "xl330-m288"], - }, - ), - "right": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem575E0032081", - motors={ - # name: (index, model) - "shoulder_pan": [1, "xl430-w250"], - "shoulder_lift": [2, "xl430-w250"], - "elbow_flex": [3, "xl330-m288"], - "wrist_flex": [4, "xl330-m288"], - "wrist_roll": [5, "xl330-m288"], - "gripper": [6, "xl330-m288"], - }, - ), - } - ) - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "laptop": OpenCVCameraConfig( - camera_index=0, - fps=30, - width=640, - height=480, - ), - "phone": OpenCVCameraConfig( - camera_index=1, - fps=30, - width=640, - height=480, - ), - } - ) - - # ~ Koch specific settings ~ - # Sets the leader arm in torque mode with the gripper motor set to this angle. This makes it possible - # to squeeze the gripper and have it spring back to an open position on its own. - gripper_open_degree: float = 35.156 - - mock: bool = False - - -@RobotConfig.register_subclass("moss") -@dataclass -class MossRobotConfig(ManipulatorRobotConfig): - calibration_dir: str = ".cache/calibration/moss" - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760431091", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "laptop": OpenCVCameraConfig( - camera_index=0, - fps=30, - width=640, - height=480, - ), - "phone": OpenCVCameraConfig( - camera_index=1, - fps=30, - width=640, - height=480, - ), - } - ) - - mock: bool = False - - -@RobotConfig.register_subclass("so100") -@dataclass -class So100RobotConfig(ManipulatorRobotConfig): - calibration_dir: str = ".cache/calibration/so100" - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem58760431091", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0076891", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "laptop": OpenCVCameraConfig( - camera_index=0, - fps=30, - width=640, - height=480, - ), - "phone": OpenCVCameraConfig( - camera_index=1, - fps=30, - width=640, - height=480, - ), - } - ) - - mock: bool = False - - -@RobotConfig.register_subclass("stretch") -@dataclass -class StretchRobotConfig(RobotConfig): - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "navigation": OpenCVCameraConfig( - camera_index="/dev/hello-nav-head-camera", - fps=10, - width=1280, - height=720, - rotation=-90, - ), - "head": IntelRealSenseCameraConfig( - name="Intel RealSense D435I", - fps=30, - width=640, - height=480, - rotation=90, - ), - "wrist": IntelRealSenseCameraConfig( - name="Intel RealSense D405", - fps=30, - width=640, - height=480, - ), - } - ) - - mock: bool = False - - -@RobotConfig.register_subclass("lekiwi") -@dataclass -class LeKiwiRobotConfig(RobotConfig): - # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. - # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as - # the number of motors in your follower arms. - max_relative_target: int | None = None - - # Network Configuration - ip: str = "192.168.0.193" - port: int = 5555 - video_port: int = 5556 - - cameras: dict[str, CameraConfig] = field( - default_factory=lambda: { - "front": OpenCVCameraConfig( - camera_index="/dev/video0", fps=30, width=640, height=480, rotation=90 - ), - "wrist": OpenCVCameraConfig( - camera_index="/dev/video2", fps=30, width=640, height=480, rotation=180 - ), - } - ) - - calibration_dir: str = ".cache/calibration/lekiwi" - - leader_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/tty.usbmodem585A0077581", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - }, - ), - } - ) - - follower_arms: dict[str, MotorsBusConfig] = field( - default_factory=lambda: { - "main": FeetechMotorsBusConfig( - port="/dev/ttyACM0", - motors={ - # name: (index, model) - "shoulder_pan": [1, "sts3215"], - "shoulder_lift": [2, "sts3215"], - "elbow_flex": [3, "sts3215"], - "wrist_flex": [4, "sts3215"], - "wrist_roll": [5, "sts3215"], - "gripper": [6, "sts3215"], - "left_wheel": (7, "sts3215"), - "back_wheel": (8, "sts3215"), - "right_wheel": (9, "sts3215"), - }, - ), - } - ) - - teleop_keys: dict[str, str] = field( - default_factory=lambda: { - # Movement - "forward": "w", - "backward": "s", - "left": "a", - "right": "d", - "rotate_left": "z", - "rotate_right": "x", - # Speed control - "speed_up": "r", - "speed_down": "f", - # quit teleop - "quit": "q", - } - ) - - mock: bool = False diff --git a/lerobot/common/robot_devices/robots/dynamixel_calibration.py b/lerobot/common/robot_devices/robots/dynamixel_calibration.py deleted file mode 100644 index 98fe8754f..000000000 --- a/lerobot/common/robot_devices/robots/dynamixel_calibration.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Logic to calibrate a robot arm built with dynamixel motors""" -# TODO(rcadene, aliberts): move this logic into the robot code when refactoring - -import numpy as np - -from lerobot.common.robot_devices.motors.dynamixel import ( - CalibrationMode, - TorqueMode, - convert_degrees_to_steps, -) -from lerobot.common.robot_devices.motors.utils import MotorsBus - -URL_TEMPLATE = ( - "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" -) - -# The following positions are provided in nominal degree range ]-180, +180[ -# For more info on these constants, see comments in the code where they get used. -ZERO_POSITION_DEGREE = 0 -ROTATED_POSITION_DEGREE = 90 - - -def assert_drive_mode(drive_mode): - # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. - if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") - - -def apply_drive_mode(position, drive_mode): - assert_drive_mode(drive_mode) - # Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted, - # to [-1, 1] with 1 indicates original rotation direction and -1 inverted. - signed_drive_mode = -(drive_mode * 2 - 1) - position *= signed_drive_mode - return position - - -def compute_nearest_rounded_position(position, models): - delta_turn = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, models) - nearest_pos = np.round(position.astype(float) / delta_turn) * delta_turn - return nearest_pos.astype(position.dtype) - - -def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): - """This function ensures that a neural network trained on data collected on a given robot - can work on another robot. For instance before calibration, setting a same goal position - for each motor of two different robots will get two very different positions. But after calibration, - the two robots will move to the same position.To this end, this function computes the homing offset - and the drive mode for each motor of a given robot. - - Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps - to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions - being 0. During the calibration process, you will need to manually move the robot to this "zero position". - - Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled - in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot - to the "rotated position". - - After calibration, the homing offsets and drive modes are stored in a cache. - - Example of usage: - ```python - run_arm_calibration(arm, "koch", "left", "follower") - ``` - """ - if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") - - print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") - - print("\nMove arm to zero position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) - input("Press Enter to continue...") - - # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. - # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will - # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position. - zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models) - - # Compute homing offset so that `present_position + homing_offset ~= target_position`. - zero_pos = arm.read("Present_Position") - zero_nearest_pos = compute_nearest_rounded_position(zero_pos, arm.motor_models) - homing_offset = zero_target_pos - zero_nearest_pos - - # The rotated target position corresponds to a rotation of a quarter turn from the zero position. - # This allows to identify the rotation direction of each motor. - # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction - # is inverted. However, for the calibration being successful, we need everyone to follow the same target position. - # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which - # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view - # of the previous motor in the kinetic chain. - print("\nMove arm to rotated target position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) - input("Press Enter to continue...") - - rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) - - # Find drive mode by rotating each motor by a quarter of a turn. - # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). - rotated_pos = arm.read("Present_Position") - drive_mode = (rotated_pos < zero_pos).astype(np.int32) - - # Re-compute homing offset to take into account drive mode - rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) - rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models) - homing_offset = rotated_target_pos - rotated_nearest_pos - - print("\nMove arm to rest position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) - input("Press Enter to continue...") - print() - - # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] - calib_mode = [CalibrationMode.DEGREE.name] * len(arm.motor_names) - - # TODO(rcadene): make type of joints (DEGREE or LINEAR) configurable from yaml? - if robot_type in ["aloha"] and "gripper" in arm.motor_names: - # Joints with linear motions (like gripper of Aloha) are expressed in nominal range of [0, 100] - calib_idx = arm.motor_names.index("gripper") - calib_mode[calib_idx] = CalibrationMode.LINEAR.name - - calib_data = { - "homing_offset": homing_offset.tolist(), - "drive_mode": drive_mode.tolist(), - "start_pos": zero_pos.tolist(), - "end_pos": rotated_pos.tolist(), - "calib_mode": calib_mode, - "motor_names": arm.motor_names, - } - return calib_data diff --git a/lerobot/common/robot_devices/robots/feetech_calibration.py b/lerobot/common/robot_devices/robots/feetech_calibration.py deleted file mode 100644 index 2c1e7180e..000000000 --- a/lerobot/common/robot_devices/robots/feetech_calibration.py +++ /dev/null @@ -1,498 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Logic to calibrate a robot arm built with feetech motors""" -# TODO(rcadene, aliberts): move this logic into the robot code when refactoring - -import time - -import numpy as np - -from lerobot.common.robot_devices.motors.feetech import ( - CalibrationMode, - TorqueMode, - convert_degrees_to_steps, -) -from lerobot.common.robot_devices.motors.utils import MotorsBus - -URL_TEMPLATE = ( - "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" -) - -# The following positions are provided in nominal degree range ]-180, +180[ -# For more info on these constants, see comments in the code where they get used. -ZERO_POSITION_DEGREE = 0 -ROTATED_POSITION_DEGREE = 90 - - -def assert_drive_mode(drive_mode): - # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. - if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") - - -def apply_drive_mode(position, drive_mode): - assert_drive_mode(drive_mode) - # Convert `drive_mode` from [0, 1] with 0 indicates original rotation direction and 1 inverted, - # to [-1, 1] with 1 indicates original rotation direction and -1 inverted. - signed_drive_mode = -(drive_mode * 2 - 1) - position *= signed_drive_mode - return position - - -def move_until_block(arm, motor_name, positive_direction=True, while_move_hook=None): - count = 0 - while True: - present_pos = arm.read("Present_Position", motor_name) - if positive_direction: - # Move +100 steps every time. Lower the steps to lower the speed at which the arm moves. - arm.write("Goal_Position", present_pos + 100, motor_name) - else: - arm.write("Goal_Position", present_pos - 100, motor_name) - - if while_move_hook is not None: - while_move_hook() - - present_pos = arm.read("Present_Position", motor_name).item() - present_speed = arm.read("Present_Speed", motor_name).item() - present_current = arm.read("Present_Current", motor_name).item() - # present_load = arm.read("Present_Load", motor_name).item() - # present_voltage = arm.read("Present_Voltage", motor_name).item() - # present_temperature = arm.read("Present_Temperature", motor_name).item() - - # print(f"{present_pos=}") - # print(f"{present_speed=}") - # print(f"{present_current=}") - # print(f"{present_load=}") - # print(f"{present_voltage=}") - # print(f"{present_temperature=}") - - if present_speed == 0 and present_current > 40: - count += 1 - if count > 100 or present_current > 300: - return present_pos - else: - count = 0 - - -def move_to_calibrate( - arm, - motor_name, - invert_drive_mode=False, - positive_first=True, - in_between_move_hook=None, - while_move_hook=None, -): - initial_pos = arm.read("Present_Position", motor_name) - - if positive_first: - p_present_pos = move_until_block( - arm, motor_name, positive_direction=True, while_move_hook=while_move_hook - ) - else: - n_present_pos = move_until_block( - arm, motor_name, positive_direction=False, while_move_hook=while_move_hook - ) - - if in_between_move_hook is not None: - in_between_move_hook() - - if positive_first: - n_present_pos = move_until_block( - arm, motor_name, positive_direction=False, while_move_hook=while_move_hook - ) - else: - p_present_pos = move_until_block( - arm, motor_name, positive_direction=True, while_move_hook=while_move_hook - ) - - zero_pos = (n_present_pos + p_present_pos) / 2 - - calib_data = { - "initial_pos": initial_pos, - "homing_offset": zero_pos if invert_drive_mode else -zero_pos, - "invert_drive_mode": invert_drive_mode, - "drive_mode": -1 if invert_drive_mode else 0, - "zero_pos": zero_pos, - "start_pos": n_present_pos if invert_drive_mode else p_present_pos, - "end_pos": p_present_pos if invert_drive_mode else n_present_pos, - } - return calib_data - - -def apply_offset(calib, offset): - calib["zero_pos"] += offset - if calib["drive_mode"]: - calib["homing_offset"] += offset - else: - calib["homing_offset"] -= offset - return calib - - -def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): - if robot_type == "so100": - return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type) - elif robot_type == "moss": - return run_arm_auto_calibration_moss(arm, robot_type, arm_name, arm_type) - else: - raise ValueError(robot_type) - - -def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): - """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" - if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") - - if not (robot_type == "so100" and arm_type == "follower"): - raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.") - - print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") - - print("\nMove arm to initial position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) - input("Press Enter to continue...") - - # Lower the acceleration of the motors (in [0,254]) - initial_acceleration = arm.read("Acceleration") - arm.write("Lock", 0) - arm.write("Acceleration", 10) - time.sleep(1) - - arm.write("Torque_Enable", TorqueMode.ENABLED.value) - - print(f'{arm.read("Present_Position", "elbow_flex")=}') - - calib = {} - - init_wf_pos = arm.read("Present_Position", "wrist_flex") - init_sl_pos = arm.read("Present_Position", "shoulder_lift") - init_ef_pos = arm.read("Present_Position", "elbow_flex") - arm.write("Goal_Position", init_wf_pos - 800, "wrist_flex") - arm.write("Goal_Position", init_sl_pos + 150 + 1024, "shoulder_lift") - arm.write("Goal_Position", init_ef_pos - 2048, "elbow_flex") - time.sleep(2) - - print("Calibrate shoulder_pan") - calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan") - arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") - time.sleep(1) - - print("Calibrate gripper") - calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True) - time.sleep(1) - - print("Calibrate wrist_flex") - calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex") - calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=80) - - def in_between_move_hook(): - nonlocal arm, calib - time.sleep(2) - ef_pos = arm.read("Present_Position", "elbow_flex") - sl_pos = arm.read("Present_Position", "shoulder_lift") - arm.write("Goal_Position", ef_pos + 1024, "elbow_flex") - arm.write("Goal_Position", sl_pos - 1024, "shoulder_lift") - time.sleep(2) - - print("Calibrate elbow_flex") - calib["elbow_flex"] = move_to_calibrate( - arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook - ) - calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) - - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex") - time.sleep(1) - - def in_between_move_hook(): - nonlocal arm, calib - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"], "elbow_flex") - - print("Calibrate shoulder_lift") - calib["shoulder_lift"] = move_to_calibrate( - arm, - "shoulder_lift", - invert_drive_mode=True, - positive_first=False, - in_between_move_hook=in_between_move_hook, - ) - # add an 30 steps as offset to align with body - calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=1024 - 50) - - def while_move_hook(): - nonlocal arm, calib - positions = { - "shoulder_lift": round(calib["shoulder_lift"]["zero_pos"] - 1600), - "elbow_flex": round(calib["elbow_flex"]["zero_pos"] + 1700), - "wrist_flex": round(calib["wrist_flex"]["zero_pos"] + 800), - "gripper": round(calib["gripper"]["end_pos"]), - } - arm.write("Goal_Position", list(positions.values()), list(positions.keys())) - - arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift") - time.sleep(2) - arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex") - time.sleep(2) - arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex") - time.sleep(2) - arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper") - time.sleep(2) - - print("Calibrate wrist_roll") - calib["wrist_roll"] = move_to_calibrate( - arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook - ) - - arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") - time.sleep(1) - arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper") - time.sleep(1) - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") - time.sleep(1) - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex") - arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift") - time.sleep(1) - arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") - time.sleep(1) - - calib_modes = [] - for name in arm.motor_names: - if name == "gripper": - calib_modes.append(CalibrationMode.LINEAR.name) - else: - calib_modes.append(CalibrationMode.DEGREE.name) - - calib_dict = { - "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names], - "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names], - "start_pos": [calib[name]["start_pos"] for name in arm.motor_names], - "end_pos": [calib[name]["end_pos"] for name in arm.motor_names], - "calib_mode": calib_modes, - "motor_names": arm.motor_names, - } - - # Re-enable original accerlation - arm.write("Lock", 0) - arm.write("Acceleration", initial_acceleration) - time.sleep(1) - - return calib_dict - - -def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): - """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" - if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") - - if not (robot_type == "moss" and arm_type == "follower"): - raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.") - - print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") - - print("\nMove arm to initial position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) - input("Press Enter to continue...") - - # Lower the acceleration of the motors (in [0,254]) - initial_acceleration = arm.read("Acceleration") - arm.write("Lock", 0) - arm.write("Acceleration", 10) - time.sleep(1) - - arm.write("Torque_Enable", TorqueMode.ENABLED.value) - - sl_pos = arm.read("Present_Position", "shoulder_lift") - arm.write("Goal_Position", sl_pos - 1024 - 450, "shoulder_lift") - ef_pos = arm.read("Present_Position", "elbow_flex") - arm.write("Goal_Position", ef_pos + 1024 + 450, "elbow_flex") - time.sleep(2) - - calib = {} - - print("Calibrate shoulder_pan") - calib["shoulder_pan"] = move_to_calibrate(arm, "shoulder_pan") - arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") - time.sleep(1) - - print("Calibrate gripper") - calib["gripper"] = move_to_calibrate(arm, "gripper", invert_drive_mode=True) - time.sleep(1) - - print("Calibrate wrist_flex") - calib["wrist_flex"] = move_to_calibrate(arm, "wrist_flex", invert_drive_mode=True) - calib["wrist_flex"] = apply_offset(calib["wrist_flex"], offset=-210 + 1024) - - wr_pos = arm.read("Present_Position", "wrist_roll") - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") - time.sleep(1) - arm.write("Goal_Position", wr_pos - 1024, "wrist_roll") - time.sleep(1) - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex") - time.sleep(1) - arm.write("Goal_Position", calib["gripper"]["end_pos"], "gripper") - time.sleep(1) - - print("Calibrate wrist_roll") - calib["wrist_roll"] = move_to_calibrate(arm, "wrist_roll", invert_drive_mode=True) - calib["wrist_roll"] = apply_offset(calib["wrist_roll"], offset=790) - - arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"] - 1024, "wrist_roll") - arm.write("Goal_Position", calib["gripper"]["start_pos"], "gripper") - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") - time.sleep(1) - arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 2048, "wrist_flex") - - def in_between_move_elbow_flex_hook(): - nonlocal arm, calib - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") - - print("Calibrate elbow_flex") - calib["elbow_flex"] = move_to_calibrate( - arm, - "elbow_flex", - invert_drive_mode=True, - in_between_move_hook=in_between_move_elbow_flex_hook, - ) - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") - - def in_between_move_shoulder_lift_hook(): - nonlocal arm, calib - sl = arm.read("Present_Position", "shoulder_lift") - arm.write("Goal_Position", sl - 1500, "shoulder_lift") - time.sleep(1) - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1536, "elbow_flex") - time.sleep(1) - arm.write("Goal_Position", calib["wrist_flex"]["start_pos"], "wrist_flex") - time.sleep(1) - - print("Calibrate shoulder_lift") - calib["shoulder_lift"] = move_to_calibrate( - arm, "shoulder_lift", in_between_move_hook=in_between_move_shoulder_lift_hook - ) - calib["shoulder_lift"] = apply_offset(calib["shoulder_lift"], offset=-1024) - - arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") - time.sleep(1) - arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift") - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex") - time.sleep(2) - - calib_modes = [] - for name in arm.motor_names: - if name == "gripper": - calib_modes.append(CalibrationMode.LINEAR.name) - else: - calib_modes.append(CalibrationMode.DEGREE.name) - - calib_dict = { - "homing_offset": [calib[name]["homing_offset"] for name in arm.motor_names], - "drive_mode": [calib[name]["drive_mode"] for name in arm.motor_names], - "start_pos": [calib[name]["start_pos"] for name in arm.motor_names], - "end_pos": [calib[name]["end_pos"] for name in arm.motor_names], - "calib_mode": calib_modes, - "motor_names": arm.motor_names, - } - - # Re-enable original accerlation - arm.write("Lock", 0) - arm.write("Acceleration", initial_acceleration) - time.sleep(1) - - return calib_dict - - -def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): - """This function ensures that a neural network trained on data collected on a given robot - can work on another robot. For instance before calibration, setting a same goal position - for each motor of two different robots will get two very different positions. But after calibration, - the two robots will move to the same position.To this end, this function computes the homing offset - and the drive mode for each motor of a given robot. - - Homing offset is used to shift the motor position to a ]-2048, +2048[ nominal range (when the motor uses 2048 steps - to complete a half a turn). This range is set around an arbitrary "zero position" corresponding to all motor positions - being 0. During the calibration process, you will need to manually move the robot to this "zero position". - - Drive mode is used to invert the rotation direction of the motor. This is useful when some motors have been assembled - in the opposite orientation for some robots. During the calibration process, you will need to manually move the robot - to the "rotated position". - - After calibration, the homing offsets and drive modes are stored in a cache. - - Example of usage: - ```python - run_arm_calibration(arm, "so100", "left", "follower") - ``` - """ - if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") - - print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") - - print("\nMove arm to zero position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) - input("Press Enter to continue...") - - # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. - # It is easy to identify and all motors are in a "quarter turn" position. Once calibration is done, this position will - # correspond to every motor angle being 0. If you set all 0 as Goal Position, the arm will move in this position. - zero_target_pos = convert_degrees_to_steps(ZERO_POSITION_DEGREE, arm.motor_models) - - # Compute homing offset so that `present_position + homing_offset ~= target_position`. - zero_pos = arm.read("Present_Position") - homing_offset = zero_target_pos - zero_pos - - # The rotated target position corresponds to a rotation of a quarter turn from the zero position. - # This allows to identify the rotation direction of each motor. - # For instance, if the motor rotates 90 degree, and its value is -90 after applying the homing offset, then we know its rotation direction - # is inverted. However, for the calibration being successful, we need everyone to follow the same target position. - # Sometimes, there is only one possible rotation direction. For instance, if the gripper is closed, there is only one direction which - # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarily rotate clockwise from the point of view - # of the previous motor in the kinetic chain. - print("\nMove arm to rotated target position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) - input("Press Enter to continue...") - - rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) - - # Find drive mode by rotating each motor by a quarter of a turn. - # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). - rotated_pos = arm.read("Present_Position") - drive_mode = (rotated_pos < zero_pos).astype(np.int32) - - # Re-compute homing offset to take into account drive mode - rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) - homing_offset = rotated_target_pos - rotated_drived_pos - - print("\nMove arm to rest position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) - input("Press Enter to continue...") - print() - - # Joints with rotational motions are expressed in degrees in nominal range of [-180, 180] - calib_modes = [] - for name in arm.motor_names: - if name == "gripper": - calib_modes.append(CalibrationMode.LINEAR.name) - else: - calib_modes.append(CalibrationMode.DEGREE.name) - - calib_dict = { - "homing_offset": homing_offset.tolist(), - "drive_mode": drive_mode.tolist(), - "start_pos": zero_pos.tolist(), - "end_pos": rotated_pos.tolist(), - "calib_mode": calib_modes, - "motor_names": arm.motor_names, - } - return calib_dict diff --git a/lerobot/common/robot_devices/robots/lekiwi_remote.py b/lerobot/common/robot_devices/robots/lekiwi_remote.py deleted file mode 100644 index 7bf52d21d..000000000 --- a/lerobot/common/robot_devices/robots/lekiwi_remote.py +++ /dev/null @@ -1,224 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import json -import threading -import time -from pathlib import Path - -import cv2 -import zmq - -from lerobot.common.robot_devices.robots.mobile_manipulator import LeKiwi - - -def setup_zmq_sockets(config): - context = zmq.Context() - cmd_socket = context.socket(zmq.PULL) - cmd_socket.setsockopt(zmq.CONFLATE, 1) - cmd_socket.bind(f"tcp://*:{config.port}") - - video_socket = context.socket(zmq.PUSH) - video_socket.setsockopt(zmq.CONFLATE, 1) - video_socket.bind(f"tcp://*:{config.video_port}") - - return context, cmd_socket, video_socket - - -def run_camera_capture(cameras, images_lock, latest_images_dict, stop_event): - while not stop_event.is_set(): - local_dict = {} - for name, cam in cameras.items(): - frame = cam.async_read() - ret, buffer = cv2.imencode(".jpg", frame, [int(cv2.IMWRITE_JPEG_QUALITY), 90]) - if ret: - local_dict[name] = base64.b64encode(buffer).decode("utf-8") - else: - local_dict[name] = "" - with images_lock: - latest_images_dict.update(local_dict) - time.sleep(0.01) - - -def calibrate_follower_arm(motors_bus, calib_dir_str): - """ - Calibrates the follower arm. Attempts to load an existing calibration file; - if not found, runs manual calibration and saves the result. - """ - calib_dir = Path(calib_dir_str) - calib_dir.mkdir(parents=True, exist_ok=True) - calib_file = calib_dir / "main_follower.json" - try: - from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration - except ImportError: - print("[WARNING] Calibration function not available. Skipping calibration.") - return - - if calib_file.exists(): - with open(calib_file) as f: - calibration = json.load(f) - print(f"[INFO] Loaded calibration from {calib_file}") - else: - print("[INFO] Calibration file not found. Running manual calibration...") - calibration = run_arm_manual_calibration(motors_bus, "lekiwi", "follower_arm", "follower") - print(f"[INFO] Calibration complete. Saving to {calib_file}") - with open(calib_file, "w") as f: - json.dump(calibration, f) - try: - motors_bus.set_calibration(calibration) - print("[INFO] Applied calibration for follower arm.") - except Exception as e: - print(f"[WARNING] Could not apply calibration: {e}") - - -def run_lekiwi(robot_config): - """ - Runs the LeKiwi robot: - - Sets up cameras and connects them. - - Initializes the follower arm motors. - - Calibrates the follower arm if necessary. - - Creates ZeroMQ sockets for receiving commands and streaming observations. - - Processes incoming commands (arm and wheel commands) and sends back sensor and camera data. - """ - # Import helper functions and classes - from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs - from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus, TorqueMode - - # Initialize cameras from the robot configuration. - cameras = make_cameras_from_configs(robot_config.cameras) - for cam in cameras.values(): - cam.connect() - - # Initialize the motors bus using the follower arm configuration. - motor_config = robot_config.follower_arms.get("main") - if motor_config is None: - print("[ERROR] Follower arm 'main' configuration not found.") - return - motors_bus = FeetechMotorsBus(motor_config) - motors_bus.connect() - - # Calibrate the follower arm. - calibrate_follower_arm(motors_bus, robot_config.calibration_dir) - - # Create the LeKiwi robot instance. - robot = LeKiwi(motors_bus) - - # Define the expected arm motor IDs. - arm_motor_ids = ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"] - - # Disable torque for each arm motor. - for motor in arm_motor_ids: - motors_bus.write("Torque_Enable", TorqueMode.DISABLED.value, motor) - - # Set up ZeroMQ sockets. - context, cmd_socket, video_socket = setup_zmq_sockets(robot_config) - - # Start the camera capture thread. - latest_images_dict = {} - images_lock = threading.Lock() - stop_event = threading.Event() - cam_thread = threading.Thread( - target=run_camera_capture, args=(cameras, images_lock, latest_images_dict, stop_event), daemon=True - ) - cam_thread.start() - - last_cmd_time = time.time() - print("LeKiwi robot server started. Waiting for commands...") - - try: - while True: - loop_start_time = time.time() - - # Process incoming commands (non-blocking). - while True: - try: - msg = cmd_socket.recv_string(zmq.NOBLOCK) - except zmq.Again: - break - try: - data = json.loads(msg) - # Process arm position commands. - if "arm_positions" in data: - arm_positions = data["arm_positions"] - if not isinstance(arm_positions, list): - print(f"[ERROR] Invalid arm_positions: {arm_positions}") - elif len(arm_positions) < len(arm_motor_ids): - print( - f"[WARNING] Received {len(arm_positions)} arm positions, expected {len(arm_motor_ids)}" - ) - else: - for motor, pos in zip(arm_motor_ids, arm_positions, strict=False): - motors_bus.write("Goal_Position", pos, motor) - # Process wheel (base) commands. - if "raw_velocity" in data: - raw_command = data["raw_velocity"] - # Expect keys: "left_wheel", "back_wheel", "right_wheel". - command_speeds = [ - int(raw_command.get("left_wheel", 0)), - int(raw_command.get("back_wheel", 0)), - int(raw_command.get("right_wheel", 0)), - ] - robot.set_velocity(command_speeds) - last_cmd_time = time.time() - except Exception as e: - print(f"[ERROR] Parsing message failed: {e}") - - # Watchdog: stop the robot if no command is received for over 0.5 seconds. - now = time.time() - if now - last_cmd_time > 0.5: - robot.stop() - last_cmd_time = now - - # Read current wheel speeds from the robot. - current_velocity = robot.read_velocity() - - # Read the follower arm state from the motors bus. - follower_arm_state = [] - for motor in arm_motor_ids: - try: - pos = motors_bus.read("Present_Position", motor) - # Convert the position to a float (or use as is if already numeric). - follower_arm_state.append(float(pos) if not isinstance(pos, (int, float)) else pos) - except Exception as e: - print(f"[ERROR] Reading motor {motor} failed: {e}") - - # Get the latest camera images. - with images_lock: - images_dict_copy = dict(latest_images_dict) - - # Build the observation dictionary. - observation = { - "images": images_dict_copy, - "present_speed": current_velocity, - "follower_arm_state": follower_arm_state, - } - # Send the observation over the video socket. - video_socket.send_string(json.dumps(observation)) - - # Ensure a short sleep to avoid overloading the CPU. - elapsed = time.time() - loop_start_time - time.sleep( - max(0.033 - elapsed, 0) - ) # If robot jitters increase the sleep and monitor cpu load with `top` in cmd - except KeyboardInterrupt: - print("Shutting down LeKiwi server.") - finally: - stop_event.set() - cam_thread.join() - robot.stop() - motors_bus.disconnect() - cmd_socket.close() - video_socket.close() - context.term() diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py deleted file mode 100644 index 9173abc62..000000000 --- a/lerobot/common/robot_devices/robots/manipulator.py +++ /dev/null @@ -1,627 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Contains logic to instantiate a robot, read information from its motors and cameras, -and send orders to its motors. -""" -# TODO(rcadene, aliberts): reorganize the codebase into one file per robot, with the associated -# calibration procedure, to make it easy for people to add their own robot. - -import json -import logging -import time -import warnings -from pathlib import Path - -import numpy as np -import torch - -from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs -from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs -from lerobot.common.robot_devices.robots.configs import ManipulatorRobotConfig -from lerobot.common.robot_devices.robots.utils import get_arm_id -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError - - -def ensure_safe_goal_position( - goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float] -): - # Cap relative action target magnitude for safety. - diff = goal_pos - present_pos - max_relative_target = torch.tensor(max_relative_target) - safe_diff = torch.minimum(diff, max_relative_target) - safe_diff = torch.maximum(safe_diff, -max_relative_target) - safe_goal_pos = present_pos + safe_diff - - if not torch.allclose(goal_pos, safe_goal_pos): - logging.warning( - "Relative goal position magnitude had to be clamped to be safe.\n" - f" requested relative goal position target: {diff}\n" - f" clamped relative goal position target: {safe_diff}" - ) - - return safe_goal_pos - - -class ManipulatorRobot: - # TODO(rcadene): Implement force feedback - """This class allows to control any manipulator robot of various number of motors. - - Non exhaustive list of robots: - - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow expansion, developed - by Alexander Koch from [Tau Robotics](https://tau-robotics.com) - - [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss - - [Aloha](https://www.trossenrobotics.com/aloha-kits) developed by Trossen Robotics - - Example of instantiation, a pre-defined robot config is required: - ```python - robot = ManipulatorRobot(KochRobotConfig()) - ``` - - Example of overwriting motors during instantiation: - ```python - # Defines how to communicate with the motors of the leader and follower arms - leader_arms = { - "main": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem575E0031751", - motors={ - # name: (index, model) - "shoulder_pan": (1, "xl330-m077"), - "shoulder_lift": (2, "xl330-m077"), - "elbow_flex": (3, "xl330-m077"), - "wrist_flex": (4, "xl330-m077"), - "wrist_roll": (5, "xl330-m077"), - "gripper": (6, "xl330-m077"), - }, - ), - } - follower_arms = { - "main": DynamixelMotorsBusConfig( - port="/dev/tty.usbmodem575E0032081", - motors={ - # name: (index, model) - "shoulder_pan": (1, "xl430-w250"), - "shoulder_lift": (2, "xl430-w250"), - "elbow_flex": (3, "xl330-m288"), - "wrist_flex": (4, "xl330-m288"), - "wrist_roll": (5, "xl330-m288"), - "gripper": (6, "xl330-m288"), - }, - ), - } - robot_config = KochRobotConfig(leader_arms=leader_arms, follower_arms=follower_arms) - robot = ManipulatorRobot(robot_config) - ``` - - Example of overwriting cameras during instantiation: - ```python - # Defines how to communicate with 2 cameras connected to the computer. - # Here, the webcam of the laptop and the phone (connected in USB to the laptop) - # can be reached respectively using the camera indices 0 and 1. These indices can be - # arbitrary. See the documentation of `OpenCVCamera` to find your own camera indices. - cameras = { - "laptop": OpenCVCamera(camera_index=0, fps=30, width=640, height=480), - "phone": OpenCVCamera(camera_index=1, fps=30, width=640, height=480), - } - robot = ManipulatorRobot(KochRobotConfig(cameras=cameras)) - ``` - - Once the robot is instantiated, connect motors buses and cameras if any (Required): - ```python - robot.connect() - ``` - - Example of highest frequency teleoperation, which doesn't require cameras: - ```python - while True: - robot.teleop_step() - ``` - - Example of highest frequency data collection from motors and cameras (if any): - ```python - while True: - observation, action = robot.teleop_step(record_data=True) - ``` - - Example of controlling the robot with a policy: - ```python - while True: - # Uses the follower arms and cameras to capture an observation - observation = robot.capture_observation() - - # Assumes a policy has been instantiated - with torch.inference_mode(): - action = policy.select_action(observation) - - # Orders the robot to move - robot.send_action(action) - ``` - - Example of disconnecting which is not mandatory since we disconnect when the object is deleted: - ```python - robot.disconnect() - ``` - """ - - def __init__( - self, - config: ManipulatorRobotConfig, - ): - self.config = config - self.robot_type = self.config.type - self.calibration_dir = Path(self.config.calibration_dir) - self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) - self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) - self.cameras = make_cameras_from_configs(self.config.cameras) - self.is_connected = False - self.logs = {} - - def get_motor_names(self, arm: dict[str, MotorsBus]) -> list: - return [f"{arm}_{motor}" for arm, bus in arm.items() for motor in bus.motors] - - @property - def camera_features(self) -> dict: - cam_ft = {} - for cam_key, cam in self.cameras.items(): - key = f"observation.images.{cam_key}" - cam_ft[key] = { - "shape": (cam.height, cam.width, cam.channels), - "names": ["height", "width", "channels"], - "info": None, - } - return cam_ft - - @property - def motor_features(self) -> dict: - action_names = self.get_motor_names(self.leader_arms) - state_names = self.get_motor_names(self.leader_arms) - return { - "action": { - "dtype": "float32", - "shape": (len(action_names),), - "names": action_names, - }, - "observation.state": { - "dtype": "float32", - "shape": (len(state_names),), - "names": state_names, - }, - } - - @property - def features(self): - return {**self.motor_features, **self.camera_features} - - @property - def has_camera(self): - return len(self.cameras) > 0 - - @property - def num_cameras(self): - return len(self.cameras) - - @property - def available_arms(self): - available_arms = [] - for name in self.follower_arms: - arm_id = get_arm_id(name, "follower") - available_arms.append(arm_id) - for name in self.leader_arms: - arm_id = get_arm_id(name, "leader") - available_arms.append(arm_id) - return available_arms - - def connect(self): - if self.is_connected: - raise RobotDeviceAlreadyConnectedError( - "ManipulatorRobot is already connected. Do not run `robot.connect()` twice." - ) - - if not self.leader_arms and not self.follower_arms and not self.cameras: - raise ValueError( - "ManipulatorRobot doesn't have any device to connect. See example of usage in docstring of the class." - ) - - # Connect the arms - for name in self.follower_arms: - print(f"Connecting {name} follower arm.") - self.follower_arms[name].connect() - for name in self.leader_arms: - print(f"Connecting {name} leader arm.") - self.leader_arms[name].connect() - - if self.robot_type in ["koch", "koch_bimanual", "aloha"]: - from lerobot.common.robot_devices.motors.dynamixel import TorqueMode - elif self.robot_type in ["so100", "moss", "lekiwi"]: - from lerobot.common.robot_devices.motors.feetech import TorqueMode - - # We assume that at connection time, arms are in a rest position, and torque can - # be safely disabled to run calibration and/or set robot preset configurations. - for name in self.follower_arms: - self.follower_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value) - for name in self.leader_arms: - self.leader_arms[name].write("Torque_Enable", TorqueMode.DISABLED.value) - - self.activate_calibration() - - # Set robot preset (e.g. torque in leader gripper for Koch v1.1) - if self.robot_type in ["koch", "koch_bimanual"]: - self.set_koch_robot_preset() - elif self.robot_type == "aloha": - self.set_aloha_robot_preset() - elif self.robot_type in ["so100", "moss", "lekiwi"]: - self.set_so100_robot_preset() - - # Enable torque on all motors of the follower arms - for name in self.follower_arms: - print(f"Activating torque on {name} follower arm.") - self.follower_arms[name].write("Torque_Enable", 1) - - if self.config.gripper_open_degree is not None: - if self.robot_type not in ["koch", "koch_bimanual"]: - raise NotImplementedError( - f"{self.robot_type} does not support position AND current control in the handle, which is require to set the gripper open." - ) - # Set the leader arm in torque mode with the gripper motor set to an angle. This makes it possible - # to squeeze the gripper and have it spring back to an open position on its own. - for name in self.leader_arms: - self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") - - # Check both arms can be read - for name in self.follower_arms: - self.follower_arms[name].read("Present_Position") - for name in self.leader_arms: - self.leader_arms[name].read("Present_Position") - - # Connect the cameras - for name in self.cameras: - self.cameras[name].connect() - - self.is_connected = True - - def activate_calibration(self): - """After calibration all motors function in human interpretable ranges. - Rotations are expressed in degrees in nominal range of [-180, 180], - and linear motions (like gripper of Aloha) in nominal range of [0, 100]. - """ - - def load_or_run_calibration_(name, arm, arm_type): - arm_id = get_arm_id(name, arm_type) - arm_calib_path = self.calibration_dir / f"{arm_id}.json" - - if arm_calib_path.exists(): - with open(arm_calib_path) as f: - calibration = json.load(f) - else: - # TODO(rcadene): display a warning in __init__ if calibration file not available - print(f"Missing calibration file '{arm_calib_path}'") - - if self.robot_type in ["koch", "koch_bimanual", "aloha"]: - from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration - - calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) - - elif self.robot_type in ["so100", "moss", "lekiwi"]: - from lerobot.common.robot_devices.robots.feetech_calibration import ( - run_arm_manual_calibration, - ) - - calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) - - print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") - arm_calib_path.parent.mkdir(parents=True, exist_ok=True) - with open(arm_calib_path, "w") as f: - json.dump(calibration, f) - - return calibration - - for name, arm in self.follower_arms.items(): - calibration = load_or_run_calibration_(name, arm, "follower") - arm.set_calibration(calibration) - for name, arm in self.leader_arms.items(): - calibration = load_or_run_calibration_(name, arm, "leader") - arm.set_calibration(calibration) - - def set_koch_robot_preset(self): - def set_operating_mode_(arm): - from lerobot.common.robot_devices.motors.dynamixel import TorqueMode - - if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run set robot preset, the torque must be disabled on all motors.") - - # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't - # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, - # you could end up with a servo with a position 0 or 4095 at a crucial point See [ - # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] - all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] - if len(all_motors_except_gripper) > 0: - # 4 corresponds to Extended Position on Koch motors - arm.write("Operating_Mode", 4, all_motors_except_gripper) - - # Use 'position control current based' for gripper to be limited by the limit of the current. - # For the follower gripper, it means it can grasp an object without forcing too much even tho, - # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). - # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger - # to make it move, and it will move back to its original target position when we release the force. - # 5 corresponds to Current Controlled Position on Koch gripper motors "xl330-m077, xl330-m288" - arm.write("Operating_Mode", 5, "gripper") - - for name in self.follower_arms: - set_operating_mode_(self.follower_arms[name]) - - # Set better PID values to close the gap between recorded states and actions - # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor - self.follower_arms[name].write("Position_P_Gain", 1500, "elbow_flex") - self.follower_arms[name].write("Position_I_Gain", 0, "elbow_flex") - self.follower_arms[name].write("Position_D_Gain", 600, "elbow_flex") - - if self.config.gripper_open_degree is not None: - for name in self.leader_arms: - set_operating_mode_(self.leader_arms[name]) - - # Enable torque on the gripper of the leader arms, and move it to 45 degrees, - # so that we can use it as a trigger to close the gripper of the follower arms. - self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") - - def set_aloha_robot_preset(self): - def set_shadow_(arm): - # Set secondary/shadow ID for shoulder and elbow. These joints have two motors. - # As a result, if only one of them is required to move to a certain position, - # the other will follow. This is to avoid breaking the motors. - if "shoulder_shadow" in arm.motor_names: - shoulder_idx = arm.read("ID", "shoulder") - arm.write("Secondary_ID", shoulder_idx, "shoulder_shadow") - - if "elbow_shadow" in arm.motor_names: - elbow_idx = arm.read("ID", "elbow") - arm.write("Secondary_ID", elbow_idx, "elbow_shadow") - - for name in self.follower_arms: - set_shadow_(self.follower_arms[name]) - - for name in self.leader_arms: - set_shadow_(self.leader_arms[name]) - - for name in self.follower_arms: - # Set a velocity limit of 131 as advised by Trossen Robotics - self.follower_arms[name].write("Velocity_Limit", 131) - - # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't - # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, - # you could end up with a servo with a position 0 or 4095 at a crucial point See [ - # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] - all_motors_except_gripper = [ - name for name in self.follower_arms[name].motor_names if name != "gripper" - ] - if len(all_motors_except_gripper) > 0: - # 4 corresponds to Extended Position on Aloha motors - self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper) - - # Use 'position control current based' for follower gripper to be limited by the limit of the current. - # It can grasp an object without forcing too much even tho, - # it's goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). - # 5 corresponds to Current Controlled Position on Aloha gripper follower "xm430-w350" - self.follower_arms[name].write("Operating_Mode", 5, "gripper") - - # Note: We can't enable torque on the leader gripper since "xc430-w150" doesn't have - # a Current Controlled Position mode. - - if self.config.gripper_open_degree is not None: - warnings.warn( - f"`gripper_open_degree` is set to {self.config.gripper_open_degree}, but None is expected for Aloha instead", - stacklevel=1, - ) - - def set_so100_robot_preset(self): - for name in self.follower_arms: - # Mode=0 for Position Control - self.follower_arms[name].write("Mode", 0) - # Set P_Coefficient to lower value to avoid shakiness (Default is 32) - self.follower_arms[name].write("P_Coefficient", 16) - # Set I_Coefficient and D_Coefficient to default value 0 and 32 - self.follower_arms[name].write("I_Coefficient", 0) - self.follower_arms[name].write("D_Coefficient", 32) - # Close the write lock so that Maximum_Acceleration gets written to EPROM address, - # which is mandatory for Maximum_Acceleration to take effect after rebooting. - self.follower_arms[name].write("Lock", 0) - # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of - # the motors. Note: this configuration is not in the official STS3215 Memory Table - self.follower_arms[name].write("Maximum_Acceleration", 254) - self.follower_arms[name].write("Acceleration", 254) - - def teleop_step( - self, record_data=False - ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - if not self.is_connected: - raise RobotDeviceNotConnectedError( - "ManipulatorRobot is not connected. You need to run `robot.connect()`." - ) - - # Prepare to assign the position of the leader to the follower - leader_pos = {} - for name in self.leader_arms: - before_lread_t = time.perf_counter() - leader_pos[name] = self.leader_arms[name].read("Present_Position") - leader_pos[name] = torch.from_numpy(leader_pos[name]) - self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t - - # Send goal position to the follower - follower_goal_pos = {} - for name in self.follower_arms: - before_fwrite_t = time.perf_counter() - goal_pos = leader_pos[name] - - # Cap goal position when too far away from present position. - # Slower fps expected due to reading from the follower. - if self.config.max_relative_target is not None: - present_pos = self.follower_arms[name].read("Present_Position") - present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) - - # Used when record_data=True - follower_goal_pos[name] = goal_pos - - goal_pos = goal_pos.numpy().astype(np.float32) - self.follower_arms[name].write("Goal_Position", goal_pos) - self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t - - # Early exit when recording data is not requested - if not record_data: - return - - # TODO(rcadene): Add velocity and other info - # Read follower position - follower_pos = {} - for name in self.follower_arms: - before_fread_t = time.perf_counter() - follower_pos[name] = self.follower_arms[name].read("Present_Position") - follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t - - # Create state by concatenating follower current position - state = [] - for name in self.follower_arms: - if name in follower_pos: - state.append(follower_pos[name]) - state = torch.cat(state) - - # Create action by concatenating follower goal position - action = [] - for name in self.follower_arms: - if name in follower_goal_pos: - action.append(follower_goal_pos[name]) - action = torch.cat(action) - - # Capture images from cameras - images = {} - for name in self.cameras: - before_camread_t = time.perf_counter() - images[name] = self.cameras[name].async_read() - images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t - - # Populate output dictionaries - obs_dict, action_dict = {}, {} - obs_dict["observation.state"] = state - action_dict["action"] = action - for name in self.cameras: - obs_dict[f"observation.images.{name}"] = images[name] - - return obs_dict, action_dict - - def capture_observation(self): - """The returned observations do not have a batch dimension.""" - if not self.is_connected: - raise RobotDeviceNotConnectedError( - "ManipulatorRobot is not connected. You need to run `robot.connect()`." - ) - - # Read follower position - follower_pos = {} - for name in self.follower_arms: - before_fread_t = time.perf_counter() - follower_pos[name] = self.follower_arms[name].read("Present_Position") - follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t - - # Create state by concatenating follower current position - state = [] - for name in self.follower_arms: - if name in follower_pos: - state.append(follower_pos[name]) - state = torch.cat(state) - - # Capture images from cameras - images = {} - for name in self.cameras: - before_camread_t = time.perf_counter() - images[name] = self.cameras[name].async_read() - images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t - - # Populate output dictionaries and format to pytorch - obs_dict = {} - obs_dict["observation.state"] = state - for name in self.cameras: - obs_dict[f"observation.images.{name}"] = images[name] - return obs_dict - - def send_action(self, action: torch.Tensor) -> torch.Tensor: - """Command the follower arms to move to a target joint configuration. - - The relative action magnitude may be clipped depending on the configuration parameter - `max_relative_target`. In this case, the action sent differs from original action. - Thus, this function always returns the action actually sent. - - Args: - action: tensor containing the concatenated goal positions for the follower arms. - """ - if not self.is_connected: - raise RobotDeviceNotConnectedError( - "ManipulatorRobot is not connected. You need to run `robot.connect()`." - ) - - from_idx = 0 - to_idx = 0 - action_sent = [] - for name in self.follower_arms: - # Get goal position of each follower arm by splitting the action vector - to_idx += len(self.follower_arms[name].motor_names) - goal_pos = action[from_idx:to_idx] - from_idx = to_idx - - # Cap goal position when too far away from present position. - # Slower fps expected due to reading from the follower. - if self.config.max_relative_target is not None: - present_pos = self.follower_arms[name].read("Present_Position") - present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) - - # Save tensor to concat and return - action_sent.append(goal_pos) - - # Send goal position to each follower - goal_pos = goal_pos.numpy().astype(np.float32) - self.follower_arms[name].write("Goal_Position", goal_pos) - - return torch.cat(action_sent) - - def print_logs(self): - pass - # TODO(aliberts): move robot-specific logs logic here - - def disconnect(self): - if not self.is_connected: - raise RobotDeviceNotConnectedError( - "ManipulatorRobot is not connected. You need to run `robot.connect()` before disconnecting." - ) - - for name in self.follower_arms: - self.follower_arms[name].disconnect() - - for name in self.leader_arms: - self.leader_arms[name].disconnect() - - for name in self.cameras: - self.cameras[name].disconnect() - - self.is_connected = False - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() diff --git a/lerobot/common/robot_devices/robots/mobile_manipulator.py b/lerobot/common/robot_devices/robots/mobile_manipulator.py deleted file mode 100644 index 385e218be..000000000 --- a/lerobot/common/robot_devices/robots/mobile_manipulator.py +++ /dev/null @@ -1,703 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import base64 -import json -import os -import sys -from pathlib import Path - -import cv2 -import numpy as np -import torch -import zmq - -from lerobot.common.robot_devices.cameras.utils import make_cameras_from_configs -from lerobot.common.robot_devices.motors.feetech import TorqueMode -from lerobot.common.robot_devices.motors.utils import MotorsBus, make_motors_buses_from_configs -from lerobot.common.robot_devices.robots.configs import LeKiwiRobotConfig -from lerobot.common.robot_devices.robots.feetech_calibration import run_arm_manual_calibration -from lerobot.common.robot_devices.robots.utils import get_arm_id -from lerobot.common.robot_devices.utils import RobotDeviceNotConnectedError - -PYNPUT_AVAILABLE = True -try: - # Only import if there's a valid X server or if we're not on a Pi - if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): - print("No DISPLAY set. Skipping pynput import.") - raise ImportError("pynput blocked intentionally due to no display.") - - from pynput import keyboard -except ImportError: - keyboard = None - PYNPUT_AVAILABLE = False -except Exception as e: - keyboard = None - PYNPUT_AVAILABLE = False - print(f"Could not import pynput: {e}") - - -class MobileManipulator: - """ - MobileManipulator is a class for connecting to and controlling a remote mobile manipulator robot. - The robot includes a three omniwheel mobile base and a remote follower arm. - The leader arm is connected locally (on the laptop) and its joint positions are recorded and then - forwarded to the remote follower arm (after applying a safety clamp). - In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels. - """ - - def __init__(self, config: LeKiwiRobotConfig): - """ - Expected keys in config: - - ip, port, video_port for the remote connection. - - calibration_dir, leader_arms, follower_arms, max_relative_target, etc. - """ - self.robot_type = config.type - self.config = config - self.remote_ip = config.ip - self.remote_port = config.port - self.remote_port_video = config.video_port - self.calibration_dir = Path(self.config.calibration_dir) - self.logs = {} - - self.teleop_keys = self.config.teleop_keys - - # For teleoperation, the leader arm (local) is used to record the desired arm pose. - self.leader_arms = make_motors_buses_from_configs(self.config.leader_arms) - - self.follower_arms = make_motors_buses_from_configs(self.config.follower_arms) - - self.cameras = make_cameras_from_configs(self.config.cameras) - - self.is_connected = False - - self.last_frames = {} - self.last_present_speed = {} - self.last_remote_arm_state = torch.zeros(6, dtype=torch.float32) - - # Define three speed levels and a current index - self.speed_levels = [ - {"xy": 0.1, "theta": 30}, # slow - {"xy": 0.2, "theta": 60}, # medium - {"xy": 0.3, "theta": 90}, # fast - ] - self.speed_index = 0 # Start at slow - - # ZeroMQ context and sockets. - self.context = None - self.cmd_socket = None - self.video_socket = None - - # Keyboard state for base teleoperation. - self.running = True - self.pressed_keys = { - "forward": False, - "backward": False, - "left": False, - "right": False, - "rotate_left": False, - "rotate_right": False, - } - - if PYNPUT_AVAILABLE: - print("pynput is available - enabling local keyboard listener.") - self.listener = keyboard.Listener( - on_press=self.on_press, - on_release=self.on_release, - ) - self.listener.start() - else: - print("pynput not available - skipping local keyboard listener.") - self.listener = None - - def get_motor_names(self, arms: dict[str, MotorsBus]) -> list: - return [f"{arm}_{motor}" for arm, bus in arms.items() for motor in bus.motors] - - @property - def camera_features(self) -> dict: - cam_ft = {} - for cam_key, cam in self.cameras.items(): - key = f"observation.images.{cam_key}" - cam_ft[key] = { - "shape": (cam.height, cam.width, cam.channels), - "names": ["height", "width", "channels"], - "info": None, - } - return cam_ft - - @property - def motor_features(self) -> dict: - follower_arm_names = [ - "shoulder_pan", - "shoulder_lift", - "elbow_flex", - "wrist_flex", - "wrist_roll", - "gripper", - ] - observations = ["x_mm", "y_mm", "theta"] - combined_names = follower_arm_names + observations - return { - "action": { - "dtype": "float32", - "shape": (len(combined_names),), - "names": combined_names, - }, - "observation.state": { - "dtype": "float32", - "shape": (len(combined_names),), - "names": combined_names, - }, - } - - @property - def features(self): - return {**self.motor_features, **self.camera_features} - - @property - def has_camera(self): - return len(self.cameras) > 0 - - @property - def num_cameras(self): - return len(self.cameras) - - @property - def available_arms(self): - available = [] - for name in self.leader_arms: - available.append(get_arm_id(name, "leader")) - for name in self.follower_arms: - available.append(get_arm_id(name, "follower")) - return available - - def on_press(self, key): - try: - # Movement - if key.char == self.teleop_keys["forward"]: - self.pressed_keys["forward"] = True - elif key.char == self.teleop_keys["backward"]: - self.pressed_keys["backward"] = True - elif key.char == self.teleop_keys["left"]: - self.pressed_keys["left"] = True - elif key.char == self.teleop_keys["right"]: - self.pressed_keys["right"] = True - elif key.char == self.teleop_keys["rotate_left"]: - self.pressed_keys["rotate_left"] = True - elif key.char == self.teleop_keys["rotate_right"]: - self.pressed_keys["rotate_right"] = True - - # Quit teleoperation - elif key.char == self.teleop_keys["quit"]: - self.running = False - return False - - # Speed control - elif key.char == self.teleop_keys["speed_up"]: - self.speed_index = min(self.speed_index + 1, 2) - print(f"Speed index increased to {self.speed_index}") - elif key.char == self.teleop_keys["speed_down"]: - self.speed_index = max(self.speed_index - 1, 0) - print(f"Speed index decreased to {self.speed_index}") - - except AttributeError: - # e.g., if key is special like Key.esc - if key == keyboard.Key.esc: - self.running = False - return False - - def on_release(self, key): - try: - if hasattr(key, "char"): - if key.char == self.teleop_keys["forward"]: - self.pressed_keys["forward"] = False - elif key.char == self.teleop_keys["backward"]: - self.pressed_keys["backward"] = False - elif key.char == self.teleop_keys["left"]: - self.pressed_keys["left"] = False - elif key.char == self.teleop_keys["right"]: - self.pressed_keys["right"] = False - elif key.char == self.teleop_keys["rotate_left"]: - self.pressed_keys["rotate_left"] = False - elif key.char == self.teleop_keys["rotate_right"]: - self.pressed_keys["rotate_right"] = False - except AttributeError: - pass - - def connect(self): - if not self.leader_arms: - raise ValueError("MobileManipulator has no leader arm to connect.") - for name in self.leader_arms: - print(f"Connecting {name} leader arm.") - self.calibrate_leader() - - # Set up ZeroMQ sockets to communicate with the remote mobile robot. - self.context = zmq.Context() - self.cmd_socket = self.context.socket(zmq.PUSH) - connection_string = f"tcp://{self.remote_ip}:{self.remote_port}" - self.cmd_socket.connect(connection_string) - self.cmd_socket.setsockopt(zmq.CONFLATE, 1) - self.video_socket = self.context.socket(zmq.PULL) - video_connection = f"tcp://{self.remote_ip}:{self.remote_port_video}" - self.video_socket.connect(video_connection) - self.video_socket.setsockopt(zmq.CONFLATE, 1) - print( - f"[INFO] Connected to remote robot at {connection_string} and video stream at {video_connection}." - ) - self.is_connected = True - - def load_or_run_calibration_(self, name, arm, arm_type): - arm_id = get_arm_id(name, arm_type) - arm_calib_path = self.calibration_dir / f"{arm_id}.json" - - if arm_calib_path.exists(): - with open(arm_calib_path) as f: - calibration = json.load(f) - else: - print(f"Missing calibration file '{arm_calib_path}'") - calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) - print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") - arm_calib_path.parent.mkdir(parents=True, exist_ok=True) - with open(arm_calib_path, "w") as f: - json.dump(calibration, f) - - return calibration - - def calibrate_leader(self): - for name, arm in self.leader_arms.items(): - # Connect the bus - arm.connect() - - # Disable torque on all motors - for motor_id in arm.motors: - arm.write("Torque_Enable", TorqueMode.DISABLED.value, motor_id) - - # Now run calibration - calibration = self.load_or_run_calibration_(name, arm, "leader") - arm.set_calibration(calibration) - - def calibrate_follower(self): - for name, bus in self.follower_arms.items(): - bus.connect() - - # Disable torque on all motors - for motor_id in bus.motors: - bus.write("Torque_Enable", 0, motor_id) - - # Then filter out wheels - arm_only_dict = {k: v for k, v in bus.motors.items() if not k.startswith("wheel_")} - if not arm_only_dict: - continue - - original_motors = bus.motors - bus.motors = arm_only_dict - - calibration = self.load_or_run_calibration_(name, bus, "follower") - bus.set_calibration(calibration) - - bus.motors = original_motors - - def _get_data(self): - """ - Polls the video socket for up to 15 ms. If data arrives, decode only - the *latest* message, returning frames, speed, and arm state. If - nothing arrives for any field, use the last known values. - """ - frames = {} - present_speed = {} - remote_arm_state_tensor = torch.zeros(6, dtype=torch.float32) - - # Poll up to 15 ms - poller = zmq.Poller() - poller.register(self.video_socket, zmq.POLLIN) - socks = dict(poller.poll(15)) - if self.video_socket not in socks or socks[self.video_socket] != zmq.POLLIN: - # No new data arrived → reuse ALL old data - return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) - - # Drain all messages, keep only the last - last_msg = None - while True: - try: - obs_string = self.video_socket.recv_string(zmq.NOBLOCK) - last_msg = obs_string - except zmq.Again: - break - - if not last_msg: - # No new message → also reuse old - return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) - - # Decode only the final message - try: - observation = json.loads(last_msg) - - images_dict = observation.get("images", {}) - new_speed = observation.get("present_speed", {}) - new_arm_state = observation.get("follower_arm_state", None) - - # Convert images - for cam_name, image_b64 in images_dict.items(): - if image_b64: - jpg_data = base64.b64decode(image_b64) - np_arr = np.frombuffer(jpg_data, dtype=np.uint8) - frame_candidate = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) - if frame_candidate is not None: - frames[cam_name] = frame_candidate - - # If remote_arm_state is None and frames is None there is no message then use the previous message - if new_arm_state is not None and frames is not None: - self.last_frames = frames - - remote_arm_state_tensor = torch.tensor(new_arm_state, dtype=torch.float32) - self.last_remote_arm_state = remote_arm_state_tensor - - present_speed = new_speed - self.last_present_speed = new_speed - else: - frames = self.last_frames - - remote_arm_state_tensor = self.last_remote_arm_state - - present_speed = self.last_present_speed - - except Exception as e: - print(f"[DEBUG] Error decoding video message: {e}") - # If decode fails, fall back to old data - return (self.last_frames, self.last_present_speed, self.last_remote_arm_state) - - return frames, present_speed, remote_arm_state_tensor - - def _process_present_speed(self, present_speed: dict) -> torch.Tensor: - state_tensor = torch.zeros(3, dtype=torch.int32) - if present_speed: - decoded = {key: MobileManipulator.raw_to_degps(value) for key, value in present_speed.items()} - if "1" in decoded: - state_tensor[0] = decoded["1"] - if "2" in decoded: - state_tensor[1] = decoded["2"] - if "3" in decoded: - state_tensor[2] = decoded["3"] - return state_tensor - - def teleop_step( - self, record_data: bool = False - ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - if not self.is_connected: - raise RobotDeviceNotConnectedError("MobileManipulator is not connected. Run `connect()` first.") - - speed_setting = self.speed_levels[self.speed_index] - xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4 - theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90 - - # Prepare to assign the position of the leader to the follower - arm_positions = [] - for name in self.leader_arms: - pos = self.leader_arms[name].read("Present_Position") - pos_tensor = torch.from_numpy(pos).float() - arm_positions.extend(pos_tensor.tolist()) - - y_cmd = 0.0 # m/s forward/backward - x_cmd = 0.0 # m/s lateral - theta_cmd = 0.0 # deg/s rotation - if self.pressed_keys["forward"]: - y_cmd += xy_speed - if self.pressed_keys["backward"]: - y_cmd -= xy_speed - if self.pressed_keys["left"]: - x_cmd += xy_speed - if self.pressed_keys["right"]: - x_cmd -= xy_speed - if self.pressed_keys["rotate_left"]: - theta_cmd += theta_speed - if self.pressed_keys["rotate_right"]: - theta_cmd -= theta_speed - - wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd) - - message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions} - self.cmd_socket.send_string(json.dumps(message)) - - if not record_data: - return - - obs_dict = self.capture_observation() - - arm_state_tensor = torch.tensor(arm_positions, dtype=torch.float32) - - wheel_velocity_tuple = self.wheel_raw_to_body(wheel_commands) - wheel_velocity_mm = ( - wheel_velocity_tuple[0] * 1000.0, - wheel_velocity_tuple[1] * 1000.0, - wheel_velocity_tuple[2], - ) - wheel_tensor = torch.tensor(wheel_velocity_mm, dtype=torch.float32) - action_tensor = torch.cat([arm_state_tensor, wheel_tensor]) - action_dict = {"action": action_tensor} - - return obs_dict, action_dict - - def capture_observation(self) -> dict: - """ - Capture observations from the remote robot: current follower arm positions, - present wheel speeds (converted to body-frame velocities: x, y, theta), - and a camera frame. - """ - if not self.is_connected: - raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.") - - frames, present_speed, remote_arm_state_tensor = self._get_data() - - body_state = self.wheel_raw_to_body(present_speed) - - body_state_mm = (body_state[0] * 1000.0, body_state[1] * 1000.0, body_state[2]) # Convert x,y to mm/s - wheel_state_tensor = torch.tensor(body_state_mm, dtype=torch.float32) - combined_state_tensor = torch.cat((remote_arm_state_tensor, wheel_state_tensor), dim=0) - - obs_dict = {"observation.state": combined_state_tensor} - - # Loop over each configured camera - for cam_name, cam in self.cameras.items(): - frame = frames.get(cam_name, None) - if frame is None: - # Create a black image using the camera's configured width, height, and channels - frame = np.zeros((cam.height, cam.width, cam.channels), dtype=np.uint8) - obs_dict[f"observation.images.{cam_name}"] = torch.from_numpy(frame) - - return obs_dict - - def send_action(self, action: torch.Tensor) -> torch.Tensor: - if not self.is_connected: - raise RobotDeviceNotConnectedError("Not connected. Run `connect()` first.") - - # Ensure the action tensor has at least 9 elements: - # - First 6: arm positions. - # - Last 3: base commands. - if action.numel() < 9: - # Pad with zeros if there are not enough elements. - padded = torch.zeros(9, dtype=action.dtype) - padded[: action.numel()] = action - action = padded - - # Extract arm and base actions. - arm_actions = action[:6].flatten() - base_actions = action[6:].flatten() - - x_cmd_mm = base_actions[0].item() # mm/s - y_cmd_mm = base_actions[1].item() # mm/s - theta_cmd = base_actions[2].item() # deg/s - - # Convert mm/s to m/s for the kinematics calculations. - x_cmd = x_cmd_mm / 1000.0 # m/s - y_cmd = y_cmd_mm / 1000.0 # m/s - - # Compute wheel commands from body commands. - wheel_commands = self.body_to_wheel_raw(x_cmd, y_cmd, theta_cmd) - - arm_positions_list = arm_actions.tolist() - - message = {"raw_velocity": wheel_commands, "arm_positions": arm_positions_list} - self.cmd_socket.send_string(json.dumps(message)) - - return action - - def print_logs(self): - pass - - def disconnect(self): - if not self.is_connected: - raise RobotDeviceNotConnectedError("Not connected.") - if self.cmd_socket: - stop_cmd = { - "raw_velocity": {"left_wheel": 0, "back_wheel": 0, "right_wheel": 0}, - "arm_positions": {}, - } - self.cmd_socket.send_string(json.dumps(stop_cmd)) - self.cmd_socket.close() - if self.video_socket: - self.video_socket.close() - if self.context: - self.context.term() - if PYNPUT_AVAILABLE: - self.listener.stop() - self.is_connected = False - print("[INFO] Disconnected from remote robot.") - - def __del__(self): - if getattr(self, "is_connected", False): - self.disconnect() - if PYNPUT_AVAILABLE: - self.listener.stop() - - @staticmethod - def degps_to_raw(degps: float) -> int: - steps_per_deg = 4096.0 / 360.0 - speed_in_steps = abs(degps) * steps_per_deg - speed_int = int(round(speed_in_steps)) - if speed_int > 0x7FFF: - speed_int = 0x7FFF - if degps < 0: - return speed_int | 0x8000 - else: - return speed_int & 0x7FFF - - @staticmethod - def raw_to_degps(raw_speed: int) -> float: - steps_per_deg = 4096.0 / 360.0 - magnitude = raw_speed & 0x7FFF - degps = magnitude / steps_per_deg - if raw_speed & 0x8000: - degps = -degps - return degps - - def body_to_wheel_raw( - self, - x_cmd: float, - y_cmd: float, - theta_cmd: float, - wheel_radius: float = 0.05, - base_radius: float = 0.125, - max_raw: int = 3000, - ) -> dict: - """ - Convert desired body-frame velocities into wheel raw commands. - - Parameters: - x_cmd : Linear velocity in x (m/s). - y_cmd : Linear velocity in y (m/s). - theta_cmd : Rotational velocity (deg/s). - wheel_radius: Radius of each wheel (meters). - base_radius : Distance from the center of rotation to each wheel (meters). - max_raw : Maximum allowed raw command (ticks) per wheel. - - Returns: - A dictionary with wheel raw commands: - {"left_wheel": value, "back_wheel": value, "right_wheel": value}. - - Notes: - - Internally, the method converts theta_cmd to rad/s for the kinematics. - - The raw command is computed from the wheels angular speed in deg/s - using degps_to_raw(). If any command exceeds max_raw, all commands - are scaled down proportionally. - """ - # Convert rotational velocity from deg/s to rad/s. - theta_rad = theta_cmd * (np.pi / 180.0) - # Create the body velocity vector [x, y, theta_rad]. - velocity_vector = np.array([x_cmd, y_cmd, theta_rad]) - - # Define the wheel mounting angles (defined from y axis cw) - angles = np.radians(np.array([300, 180, 60])) - # Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed. - # The third column (base_radius) accounts for the effect of rotation. - m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) - - # Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s). - wheel_linear_speeds = m.dot(velocity_vector) - wheel_angular_speeds = wheel_linear_speeds / wheel_radius - - # Convert wheel angular speeds from rad/s to deg/s. - wheel_degps = wheel_angular_speeds * (180.0 / np.pi) - - # Scaling - steps_per_deg = 4096.0 / 360.0 - raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps] - max_raw_computed = max(raw_floats) - if max_raw_computed > max_raw: - scale = max_raw / max_raw_computed - wheel_degps = wheel_degps * scale - - # Convert each wheel’s angular speed (deg/s) to a raw integer. - wheel_raw = [MobileManipulator.degps_to_raw(deg) for deg in wheel_degps] - - return {"left_wheel": wheel_raw[0], "back_wheel": wheel_raw[1], "right_wheel": wheel_raw[2]} - - def wheel_raw_to_body( - self, wheel_raw: dict, wheel_radius: float = 0.05, base_radius: float = 0.125 - ) -> tuple: - """ - Convert wheel raw command feedback back into body-frame velocities. - - Parameters: - wheel_raw : Dictionary with raw wheel commands (keys: "left_wheel", "back_wheel", "right_wheel"). - wheel_radius: Radius of each wheel (meters). - base_radius : Distance from the robot center to each wheel (meters). - - Returns: - A tuple (x_cmd, y_cmd, theta_cmd) where: - x_cmd : Linear velocity in x (m/s). - y_cmd : Linear velocity in y (m/s). - theta_cmd : Rotational velocity in deg/s. - """ - # Extract the raw values in order. - raw_list = [ - int(wheel_raw.get("left_wheel", 0)), - int(wheel_raw.get("back_wheel", 0)), - int(wheel_raw.get("right_wheel", 0)), - ] - - # Convert each raw command back to an angular speed in deg/s. - wheel_degps = np.array([MobileManipulator.raw_to_degps(r) for r in raw_list]) - # Convert from deg/s to rad/s. - wheel_radps = wheel_degps * (np.pi / 180.0) - # Compute each wheel’s linear speed (m/s) from its angular speed. - wheel_linear_speeds = wheel_radps * wheel_radius - - # Define the wheel mounting angles (defined from y axis cw) - angles = np.radians(np.array([300, 180, 60])) - m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) - - # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds. - m_inv = np.linalg.inv(m) - velocity_vector = m_inv.dot(wheel_linear_speeds) - x_cmd, y_cmd, theta_rad = velocity_vector - theta_cmd = theta_rad * (180.0 / np.pi) - return (x_cmd, y_cmd, theta_cmd) - - -class LeKiwi: - def __init__(self, motor_bus): - """ - Initializes the LeKiwi with Feetech motors bus. - """ - self.motor_bus = motor_bus - self.motor_ids = ["left_wheel", "back_wheel", "right_wheel"] - - # Initialize motors in velocity mode. - self.motor_bus.write("Lock", 0) - self.motor_bus.write("Mode", [1, 1, 1], self.motor_ids) - self.motor_bus.write("Lock", 1) - print("Motors set to velocity mode.") - - def read_velocity(self): - """ - Reads the raw speeds for all wheels. Returns a dictionary with motor names: - """ - raw_speeds = self.motor_bus.read("Present_Speed", self.motor_ids) - return { - "left_wheel": int(raw_speeds[0]), - "back_wheel": int(raw_speeds[1]), - "right_wheel": int(raw_speeds[2]), - } - - def set_velocity(self, command_speeds): - """ - Sends raw velocity commands (16-bit encoded values) directly to the motor bus. - The order of speeds must correspond to self.motor_ids. - """ - self.motor_bus.write("Goal_Speed", command_speeds, self.motor_ids) - - def stop(self): - """Stops the robot by setting all motor speeds to zero.""" - self.motor_bus.write("Goal_Speed", [0, 0, 0], self.motor_ids) - print("Motors stopped.") diff --git a/lerobot/common/robot_devices/robots/stretch.py b/lerobot/common/robot_devices/robots/stretch.py deleted file mode 100644 index 9cfe6e490..000000000 --- a/lerobot/common/robot_devices/robots/stretch.py +++ /dev/null @@ -1,208 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import time -from dataclasses import replace - -import torch -from stretch_body.gamepad_teleop import GamePadTeleop -from stretch_body.robot import Robot as StretchAPI -from stretch_body.robot_params import RobotParams - -from lerobot.common.robot_devices.robots.configs import StretchRobotConfig - - -class StretchRobot(StretchAPI): - """Wrapper of stretch_body.robot.Robot""" - - def __init__(self, config: StretchRobotConfig | None = None, **kwargs): - super().__init__() - if config is None: - self.config = StretchRobotConfig(**kwargs) - else: - # Overwrite config arguments using kwargs - self.config = replace(config, **kwargs) - - self.robot_type = self.config.type - self.cameras = self.config.cameras - self.is_connected = False - self.teleop = None - self.logs = {} - - # TODO(aliberts): test this - RobotParams.set_logging_level("WARNING") - RobotParams.set_logging_formatter("brief_console_formatter") - - self.state_keys = None - self.action_keys = None - - def connect(self) -> None: - self.is_connected = self.startup() - if not self.is_connected: - print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") - raise ConnectionError() - - for name in self.cameras: - self.cameras[name].connect() - self.is_connected = self.is_connected and self.cameras[name].is_connected - - if not self.is_connected: - print("Could not connect to the cameras, check that all cameras are plugged-in.") - raise ConnectionError() - - self.run_calibration() - - def run_calibration(self) -> None: - if not self.is_homed(): - self.home() - - def teleop_step( - self, record_data=False - ) -> None | tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: - # TODO(aliberts): return ndarrays instead of torch.Tensors - if not self.is_connected: - raise ConnectionError() - - if self.teleop is None: - self.teleop = GamePadTeleop(robot_instance=False) - self.teleop.startup(robot=self) - - before_read_t = time.perf_counter() - state = self.get_state() - action = self.teleop.gamepad_controller.get_state() - self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t - - before_write_t = time.perf_counter() - self.teleop.do_motion(robot=self) - self.push_command() - self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t - - if self.state_keys is None: - self.state_keys = list(state) - - if not record_data: - return - - state = torch.as_tensor(list(state.values())) - action = torch.as_tensor(list(action.values())) - - # Capture images from cameras - images = {} - for name in self.cameras: - before_camread_t = time.perf_counter() - images[name] = self.cameras[name].async_read() - images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t - - # Populate output dictionaries - obs_dict, action_dict = {}, {} - obs_dict["observation.state"] = state - action_dict["action"] = action - for name in self.cameras: - obs_dict[f"observation.images.{name}"] = images[name] - - return obs_dict, action_dict - - def get_state(self) -> dict: - status = self.get_status() - return { - "head_pan.pos": status["head"]["head_pan"]["pos"], - "head_tilt.pos": status["head"]["head_tilt"]["pos"], - "lift.pos": status["lift"]["pos"], - "arm.pos": status["arm"]["pos"], - "wrist_pitch.pos": status["end_of_arm"]["wrist_pitch"]["pos"], - "wrist_roll.pos": status["end_of_arm"]["wrist_roll"]["pos"], - "wrist_yaw.pos": status["end_of_arm"]["wrist_yaw"]["pos"], - "gripper.pos": status["end_of_arm"]["stretch_gripper"]["pos"], - "base_x.vel": status["base"]["x_vel"], - "base_y.vel": status["base"]["y_vel"], - "base_theta.vel": status["base"]["theta_vel"], - } - - def capture_observation(self) -> dict: - # TODO(aliberts): return ndarrays instead of torch.Tensors - before_read_t = time.perf_counter() - state = self.get_state() - self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t - - if self.state_keys is None: - self.state_keys = list(state) - - state = torch.as_tensor(list(state.values())) - - # Capture images from cameras - images = {} - for name in self.cameras: - before_camread_t = time.perf_counter() - images[name] = self.cameras[name].async_read() - images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t - - # Populate output dictionaries - obs_dict = {} - obs_dict["observation.state"] = state - for name in self.cameras: - obs_dict[f"observation.images.{name}"] = images[name] - - return obs_dict - - def send_action(self, action: torch.Tensor) -> torch.Tensor: - # TODO(aliberts): return ndarrays instead of torch.Tensors - if not self.is_connected: - raise ConnectionError() - - if self.teleop is None: - self.teleop = GamePadTeleop(robot_instance=False) - self.teleop.startup(robot=self) - - if self.action_keys is None: - dummy_action = self.teleop.gamepad_controller.get_state() - self.action_keys = list(dummy_action.keys()) - - action_dict = dict(zip(self.action_keys, action.tolist(), strict=True)) - - before_write_t = time.perf_counter() - self.teleop.do_motion(state=action_dict, robot=self) - self.push_command() - self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t - - # TODO(aliberts): return action_sent when motion is limited - return action - - def print_logs(self) -> None: - pass - # TODO(aliberts): move robot-specific logs logic here - - def teleop_safety_stop(self) -> None: - if self.teleop is not None: - self.teleop._safety_stop(robot=self) - - def disconnect(self) -> None: - self.stop() - if self.teleop is not None: - self.teleop.gamepad_controller.stop() - self.teleop.stop() - - if len(self.cameras) > 0: - for cam in self.cameras.values(): - cam.disconnect() - - self.is_connected = False - - def __del__(self): - self.disconnect() diff --git a/lerobot/common/robot_devices/robots/utils.py b/lerobot/common/robot_devices/robots/utils.py deleted file mode 100644 index dab514d5e..000000000 --- a/lerobot/common/robot_devices/robots/utils.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Protocol - -from lerobot.common.robot_devices.robots.configs import ( - AlohaRobotConfig, - KochBimanualRobotConfig, - KochRobotConfig, - LeKiwiRobotConfig, - ManipulatorRobotConfig, - MossRobotConfig, - RobotConfig, - So100RobotConfig, - StretchRobotConfig, -) - - -def get_arm_id(name, arm_type): - """Returns the string identifier of a robot arm. For instance, for a bimanual manipulator - like Aloha, it could be left_follower, right_follower, left_leader, or right_leader. - """ - return f"{name}_{arm_type}" - - -class Robot(Protocol): - # TODO(rcadene, aliberts): Add unit test checking the protocol is implemented in the corresponding classes - robot_type: str - features: dict - - def connect(self): ... - def run_calibration(self): ... - def teleop_step(self, record_data=False): ... - def capture_observation(self): ... - def send_action(self, action): ... - def disconnect(self): ... - - -def make_robot_config(robot_type: str, **kwargs) -> RobotConfig: - if robot_type == "aloha": - return AlohaRobotConfig(**kwargs) - elif robot_type == "koch": - return KochRobotConfig(**kwargs) - elif robot_type == "koch_bimanual": - return KochBimanualRobotConfig(**kwargs) - elif robot_type == "moss": - return MossRobotConfig(**kwargs) - elif robot_type == "so100": - return So100RobotConfig(**kwargs) - elif robot_type == "stretch": - return StretchRobotConfig(**kwargs) - elif robot_type == "lekiwi": - return LeKiwiRobotConfig(**kwargs) - else: - raise ValueError(f"Robot type '{robot_type}' is not available.") - - -def make_robot_from_config(config: RobotConfig): - if isinstance(config, ManipulatorRobotConfig): - from lerobot.common.robot_devices.robots.manipulator import ManipulatorRobot - - return ManipulatorRobot(config) - elif isinstance(config, LeKiwiRobotConfig): - from lerobot.common.robot_devices.robots.mobile_manipulator import MobileManipulator - - return MobileManipulator(config) - else: - from lerobot.common.robot_devices.robots.stretch import StretchRobot - - return StretchRobot(config) - - -def make_robot(robot_type: str, **kwargs) -> Robot: - config = make_robot_config(robot_type, **kwargs) - return make_robot_from_config(config) diff --git a/lerobot/common/robots/__init__.py b/lerobot/common/robots/__init__.py new file mode 100644 index 000000000..d8fd0de93 --- /dev/null +++ b/lerobot/common/robots/__init__.py @@ -0,0 +1,3 @@ +from .config import RobotConfig +from .robot import Robot +from .utils import make_robot_from_config diff --git a/lerobot/common/robots/config.py b/lerobot/common/robots/config.py new file mode 100644 index 000000000..a85a83169 --- /dev/null +++ b/lerobot/common/robots/config.py @@ -0,0 +1,40 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from dataclasses import dataclass +from pathlib import Path + +import draccus + + +@dataclass(kw_only=True) +class RobotConfig(draccus.ChoiceRegistry, abc.ABC): + # Allows to distinguish between different robots of the same type + id: str | None = None + # Directory to store calibration file + calibration_dir: Path | None = None + + def __post_init__(self): + if hasattr(self, "cameras") and self.cameras: + for _, config in self.cameras.items(): + for attr in ["width", "height", "fps"]: + if getattr(config, attr) is None: + raise ValueError( + f"Specifying '{attr}' is required for the camera to be used in a robot" + ) + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) diff --git a/lerobot/common/robots/koch_follower/__init__.py b/lerobot/common/robots/koch_follower/__init__.py new file mode 100644 index 000000000..ae98a2c38 --- /dev/null +++ b/lerobot/common/robots/koch_follower/__init__.py @@ -0,0 +1,2 @@ +from .config_koch_follower import KochFollowerConfig +from .koch_follower import KochFollower diff --git a/lerobot/common/robots/koch_follower/config_koch_follower.py b/lerobot/common/robots/koch_follower/config_koch_follower.py new file mode 100644 index 000000000..6ac164726 --- /dev/null +++ b/lerobot/common/robots/koch_follower/config_koch_follower.py @@ -0,0 +1,39 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.cameras import CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("koch_follower") +@dataclass +class KochFollowerConfig(RobotConfig): + # Port to connect to the arm + port: str + + disable_torque_on_disconnect: bool = True + + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + # cameras + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Set to `True` for backward compatibility with previous policies/dataset + use_degrees: bool = False diff --git a/lerobot/common/robots/koch_follower/koch.mdx b/lerobot/common/robots/koch_follower/koch.mdx new file mode 100644 index 000000000..c39865944 --- /dev/null +++ b/lerobot/common/robots/koch_follower/koch.mdx @@ -0,0 +1,258 @@ +# Koch v1.1 + +In the steps below, we explain how to assemble the Koch v1.1 robot. + +## Order and assemble the parts + +Follow the sourcing and assembling instructions provided in this [README](https://github.com/jess-moss/koch-v1-1). This will guide you through setting up both the follower and leader arms, as shown in the image below. + +For a visual walkthrough of the assembly process, you can refer to [this video tutorial](https://youtu.be/8nQIg9BwwTk). + +> [!WARNING] +> Since the production of this video, we simplified the configuration phase. Because of this, two things differ from the instructions in that video: +> - Don't plug in all the motor cables right away and wait to be instructed to do so in [Configure the motors](#configure-the-motors). +> - Don't screw in the controller board (PCB) to the base right away and wait for being instructed to do so in [Configure the motors](#configure-the-motors). + + +## Install LeRobot 🤗 + +To install LeRobot follow, our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Dynamixel SDK: +```bash +pip install -e ".[dynamixel]" +``` + +## Configure the motors + +### 1. Find the USB ports associated with each arm + +To find the port for each bus servo adapter, run this script: +```bash +python lerobot/find_port.py +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, you might need to give access to the USB ports by running: +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0', '/dev/ttyACM1'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM1 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. + + + + +### 2. Set the motors ids and baudrates + +Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. + +To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. + +If you are repurposing motors from another robot, you will probably also need to perform this step, as the ids and baudrate likely won't match. + +#### Follower + +Connect the usb cable from your computer and the 5V power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. + +For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm. + + + + +```bash +python -m lerobot.setup_motors \ + --robot.type=koch_follower \ + --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step +``` + + + +```python +from lerobot.common.robots.koch_follower import KochFollower, KochFollowerConfig + +config = KochFollowerConfig( + port="/dev/tty.usbmodem575E0031751", + id="my_awesome_follower_arm", +) +follower = KochFollower(config) +follower.setup_motors() +``` + + + +You should see the following instruction. +``` +Connect the controller board to the 'gripper' motor only and press enter. +``` + +As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. + +
+Troubleshooting + + If you get an error at that point, check your cables and make sure they are plugged in properly: +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ + If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). +
+ +You should then see the following message: +``` +'gripper' motor id set to 6 +``` + +Followed by the next instruction: +``` +Connect the controller board to the 'wrist_roll' motor only and press enter. +``` + +You can disconnect the 3-pin cable from the controller board but you can leave it connected to the gripper motor on the other end as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. + +Repeat the operation for each motor as instructed. + +> [!TIP] +> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. + +When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. + +#### Leader +Do the same steps for the leader arm but modify the command or script accordingly. + + + + +```bash +python -m lerobot.setup_motors \ + --teleop.type=koch_leader \ + --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step +``` + + + +```python +from lerobot.common.teleoperators.koch_leader import KochLeader, KochLeaderConfig + +config = KochLeaderConfig( + port="/dev/tty.usbmodem575E0031751", + id="my_awesome_leader_arm", +) +leader = KochLeader(config) +leader.setup_motors() +``` + + + +## Calibrate + +Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +#### Follower + +Run the following command or API example to calibrate the follower arm: + + + + +```bash +python -m lerobot.calibrate \ + --robot.type=koch_follower \ + --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --robot.id=my_awesome_follower_arm # <- Give the robot a unique name +``` + + + +```python +from lerobot.common.robots.koch_follower import KochFollowerConfig, KochFollower + +config = KochFollowerConfig( + port="/dev/tty.usbmodem585A0076891", + id="my_awesome_follower_arm", +) + +follower = KochFollower(config) +follower.connect(calibrate=False) +follower.calibrate() +follower.disconnect() +``` + + + +We unified the calibration method for most robots. Thus, the calibration steps for this Koch arm are the same as the steps for the SO100 and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video). + +#### Leader + +Do the same steps to calibrate the leader arm, run the following command or API example: + + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=koch_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + +```python +from lerobot.common.teleoperators.koch_leader import KochLeaderConfig, KochLeader + +config = KochLeaderConfig( + port="/dev/tty.usbmodem575E0031751", + id="my_awesome_leader_arm", +) + +leader = KochLeader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/lerobot/common/robots/koch_follower/koch_follower.py b/lerobot/common/robots/koch_follower/koch_follower.py new file mode 100644 index 000000000..64ece25f2 --- /dev/null +++ b/lerobot/common/robots/koch_follower/koch_follower.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.dynamixel import ( + DynamixelMotorsBus, + OperatingMode, +) + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_koch_follower import KochFollowerConfig + +logger = logging.getLogger(__name__) + + +class KochFollower(Robot): + """ + - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow + expansion, developed by Alexander Koch from [Tau Robotics](https://tau-robotics.com) + - [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss + """ + + config_class = KochFollowerConfig + name = "koch_follower" + + def __init__(self, config: KochFollowerConfig): + super().__init__(config) + self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 + self.bus = DynamixelMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "xl430-w250", norm_mode_body), + "shoulder_lift": Motor(2, "xl430-w250", norm_mode_body), + "elbow_flex": Motor(3, "xl330-m288", norm_mode_body), + "wrist_flex": Motor(4, "xl330-m288", norm_mode_body), + "wrist_roll": Motor(5, "xl330-m288", norm_mode_body), + "gripper": Motor(6, "xl330-m288", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + We assume that at connection time, arm is in a rest position, + and torque can be safely disabled to run calibration. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + full_turn_motors = ["shoulder_pan", "wrist_roll"] + unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors] + print( + f"Move all joints except {full_turn_motors} sequentially through their entire " + "ranges of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + for motor in full_turn_motors: + range_mins[motor] = 0 + range_maxes[motor] = 4095 + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + logger.info(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos + # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling + # the arm, you could end up with a servo with a position 0 or 4095 at a crucial point + for motor in self.bus.motors: + if motor != "gripper": + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + # Use 'position control current based' for gripper to be limited by the limit of the current. For + # the follower gripper, it means it can grasp an object without forcing too much even tho, its + # goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # For the leader gripper, it means we can use it as a physical trigger, since we can force with + # our finger to make it move, and it will move back to its original target position when we + # release the force. + self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + + # Set better PID values to close the gap between recorded states and actions + # TODO(rcadene): Implement an automatic procedure to set optimal PID values for each motor + self.bus.write("Position_P_Gain", "elbow_flex", 1500) + self.bus.write("Position_I_Gain", "elbow_flex", 0) + self.bus.write("Position_D_Gain", "elbow_flex", 600) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, float]) -> dict[str, float]: + """Command arm to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Args: + action (dict[str, float]): The goal positions for the motors. + + Returns: + dict[str, float]: The action sent to the motors, potentially clipped. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + # Send goal position to the arm + self.bus.sync_write("Goal_Position", goal_pos) + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/robots/lekiwi/__init__.py b/lerobot/common/robots/lekiwi/__init__.py new file mode 100644 index 000000000..e3d10c5c1 --- /dev/null +++ b/lerobot/common/robots/lekiwi/__init__.py @@ -0,0 +1,3 @@ +from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig +from .lekiwi import LeKiwi +from .lekiwi_client import LeKiwiClient diff --git a/lerobot/common/robots/lekiwi/config_lekiwi.py b/lerobot/common/robots/lekiwi/config_lekiwi.py new file mode 100644 index 000000000..022d09cdd --- /dev/null +++ b/lerobot/common/robots/lekiwi/config_lekiwi.py @@ -0,0 +1,96 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.cameras.configs import CameraConfig, Cv2Rotation +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig + +from ..config import RobotConfig + + +def lekiwi_cameras_config() -> dict[str, CameraConfig]: + return { + "front": OpenCVCameraConfig( + index_or_path="/dev/video0", fps=30, width=640, height=480, rotation=Cv2Rotation.ROTATE_180 + ), + "wrist": OpenCVCameraConfig( + index_or_path="/dev/video2", fps=30, width=480, height=640, rotation=Cv2Rotation.ROTATE_90 + ), + } + + +@RobotConfig.register_subclass("lekiwi") +@dataclass +class LeKiwiConfig(RobotConfig): + port: str = "/dev/ttyACM0" # port to connect to the bus + + disable_torque_on_disconnect: bool = True + + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) + + # Set to `True` for backward compatibility with previous policies/dataset + use_degrees: bool = False + + +@dataclass +class LeKiwiHostConfig: + # Network Configuration + port_zmq_cmd: int = 5555 + port_zmq_observations: int = 5556 + + # Duration of the application + connection_time_s: int = 30 + + # Watchdog: stop the robot if no command is received for over 0.5 seconds. + watchdog_timeout_ms: int = 500 + + # If robot jitters decrease the frequency and monitor cpu load with `top` in cmd + max_loop_freq_hz: int = 30 + + +@RobotConfig.register_subclass("lekiwi_client") +@dataclass +class LeKiwiClientConfig(RobotConfig): + # Network Configuration + remote_ip: str + port_zmq_cmd: int = 5555 + port_zmq_observations: int = 5556 + + teleop_keys: dict[str, str] = field( + default_factory=lambda: { + # Movement + "forward": "w", + "backward": "s", + "left": "a", + "right": "d", + "rotate_left": "z", + "rotate_right": "x", + # Speed control + "speed_up": "r", + "speed_down": "f", + # quit teleop + "quit": "q", + } + ) + + cameras: dict[str, CameraConfig] = field(default_factory=lekiwi_cameras_config) + + polling_timeout_ms: int = 15 + connect_timeout_s: int = 5 diff --git a/lerobot/common/robots/lekiwi/lekiwi.mdx b/lerobot/common/robots/lekiwi/lekiwi.mdx new file mode 100644 index 000000000..6eaebce79 --- /dev/null +++ b/lerobot/common/robots/lekiwi/lekiwi.mdx @@ -0,0 +1,300 @@ +# LeKiwi + +In the steps below, we explain how to assemble the LeKiwi mobile robot. + +## Source the parts + +Follow this [README](https://github.com/SIGRobotics-UIUC/LeKiwi). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. +And advise if it's your first time printing or if you don't own a 3D printer. + +### Wired version +If you have the **wired** LeKiwi version, you can skip the installation of the Raspberry Pi and setting up SSH. You can also run all commands directly on your PC for both the LeKiwi scripts and the leader arm scripts for teleoperating. + +## Install software on Pi +Now we have to set up the remote PC that will run on the LeKiwi Robot. This is normally a Raspberry Pi, but can be any PC that can run on 5V and has enough usb ports (2 or more) for the cameras and motor control board. + +### Install OS +For setting up the Raspberry Pi and its SD-card see: [Setup PI](https://www.raspberrypi.com/documentation/computers/getting-started.html). Here is explained how to download the [Imager](https://www.raspberrypi.com/software/) to install Raspberry Pi OS or Ubuntu. + +### Setup SSH +After setting up your Pi, you should enable and set up [SSH](https://www.raspberrypi.com/news/coding-on-raspberry-pi-remotely-with-visual-studio-code/) (Secure Shell Protocol) so you can log in to the Pi from your laptop without requiring a screen, keyboard, and mouse on the Pi. A great tutorial on how to do this can be found [here](https://www.raspberrypi.com/documentation/computers/remote-access.html#ssh). Logging into your Pi can be done in your Command Prompt (cmd) or, if you use VSCode you can use [this](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-ssh) extension. + +### Install LeRobot on Pi 🤗 + +On your Raspberry Pi install LeRobot using our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Feetech sdk on your Pi: +```bash +pip install -e ".[feetech]" +``` + +## Install LeRobot locally +If you already have installed LeRobot on your laptop/pc you can skip this step; otherwise, please follow along as we do the same steps we did on the Pi. + +Follow our [Installation Guide](./installation) + +Great :hugs:! You are now done installing LeRobot, and we can begin assembling the SO100/SO101 arms and the mobile base :robot:. +Every time you now want to use LeRobot, you can go to the `~/lerobot` folder where we installed LeRobot and run one of the commands. + +# Step-by-Step Assembly Instructions + +First, we will assemble the two SO100/SO101 arms. One to attach to the mobile base and one for teleoperation. Then we will assemble the mobile base. The instructions for assembling can be found on these two pages: + +- [Assemble SO101](./so101#step-by-step-assembly-instructions) +- [Assemble LeKiwi](https://github.com/SIGRobotics-UIUC/LeKiwi/blob/main/Assembly.md) + +### Find the USB ports associated with motor board + +To find the port for each bus servo adapter, run this script: +```bash +python lerobot/find_port.py +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your board. + + + + +On Linux, you might need to give access to the USB ports by running: +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM0 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM0` corresponding to your board. + + + + +### Configure motors +The instructions for configuring the motors can be found in the SO101 [docs](./so101#configure-the-motors). Besides the ids for the arm motors, we also need to set the motor ids for the mobile base. These need to be in a specific order to work. Below an image of the motor ids and motor mounting positions for the mobile base. Note that we only use one Motor Control board on LeKiwi. This means the motor ids for the wheels are 7, 8 and 9. + +You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7) + +```bash +python -m lerobot.setup_motors \ + --robot.type=lekiwi \ + --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step +``` + +Motor ID's for mobile robot + +### Troubleshoot communication + +If you are having trouble connecting to the Mobile SO100, follow these steps to diagnose and resolve the issue. + +#### 1. Verify IP Address Configuration +Make sure that the correct IP for the Pi is used in the commands or in your code. To check the Raspberry Pi's IP address, run (on the Pi command line): +```bash +hostname -I +``` + +#### 2. Check if Pi is reachable from laptop/pc +Try pinging the Raspberry Pi from your laptop: +```bach +ping +``` + +If the ping fails: +- Ensure the Pi is powered on and connected to the same network. +- Check if SSH is enabled on the Pi. + +#### 3. Try SSH connection +If you can't SSH into the Pi, it might not be properly connected. Use: +```bash +ssh @ +``` +If you get a connection error: +- Ensure SSH is enabled on the Pi by running: + ```bash + sudo raspi-config + ``` + Then navigate to: **Interfacing Options -> SSH** and enable it. + +### Calibration + +Now we have to calibrate the leader arm and the follower arm. The wheel motors don't have to be calibrated. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +### Calibrate follower arm (on mobile base) + +Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm: + +```bash +python -m lerobot.calibrate \ + --robot.type=lekiwi \ + --robot.id=my_awesome_kiwi # <- Give the robot a unique name +``` + +We unified the calibration method for most robots, thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video). + +### Wired version +If you have the **wired** LeKiwi version, please run all commands on your laptop. + +### Calibrate leader arm +Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the following command of API example on your laptop: + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + +```python +from lerobot.common.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader + +config = SO100LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_awesome_leader_arm", +) + +leader = SO100Leader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + +## Teleoperate LeKiwi + +> [!TIP] +> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal. + +To teleoperate, SSH into your Raspberry Pi, and run `conda activate lerobot` and this command: +```bash +python -m lerobot.common.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi +``` + +Then on your laptop, also run `conda activate lerobot` and run the API example, make sure you set the correct `remote_ip` and `port` in `examples/lekiwi/teleoperate.py`. + +```bash +python examples/lekiwi/teleoperate.py +``` + +You should see on your laptop something like this: ```[INFO] Connected to remote robot at tcp://172.17.133.91:5555 and video stream at tcp://172.17.133.91:5556.``` Now you can move the leader arm and use the keyboard (w,a,s,d) to drive forward, left, backwards, right. And use (z,x) to turn left or turn right. You can use (r,f) to increase and decrease the speed of the mobile robot. There are three speed modes, see the table below: + +| Speed Mode | Linear Speed (m/s) | Rotation Speed (deg/s) | +| ---------- | ------------------ | ---------------------- | +| Fast | 0.4 | 90 | +| Medium | 0.25 | 60 | +| Slow | 0.1 | 30 | + + +| Key | Action | +| --- | -------------- | +| W | Move forward | +| A | Move left | +| S | Move backward | +| D | Move right | +| Z | Turn left | +| X | Turn right | +| R | Increase speed | +| F | Decrease speed | + +> [!TIP] +> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../lerobot/common/robot_devices/robots/configs.py). + +### Wired version +If you have the **wired** LeKiwi version, please run all commands on your laptop. + +## Record a dataset + +Once you're familiar with teleoperation, you can record your first dataset. + +We use the Hugging Face hub features for uploading your dataset. If you haven't previously used the Hub, make sure you can login via the cli using a write-access token, this token can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens). + +Add your token to the CLI by running this command: +```bash +huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +``` + +Then store your Hugging Face repository name in a variable: +```bash +HF_USER=$(huggingface-cli whoami | head -n 1) +echo $HF_USER +``` + +Now you can record a dataset. To record episodes and upload your dataset to the hub, execute this API example tailored for LeKiwi. Make sure to first adapt the `remote_ip`, `repo_id`, `port` and `task` in the script. If you would like to run the script for longer you can increase `NB_CYCLES_CLIENT_CONNECTION`. +```bash +python examples/lekiwi/record.py +``` + +#### Dataset upload +Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: +```bash +echo https://huggingface.co/datasets/${HF_USER}/so101_test +``` +Your dataset will be automatically tagged with `LeRobot` for the community to find it easily, and you can also add custom tags (in this case `tutorial` for example). + +You can look for other LeRobot datasets on the hub by searching for `LeRobot` [tags](https://huggingface.co/datasets?other=LeRobot). + +#### Tips for gathering data + +Once you're comfortable with data recording, you can create a larger dataset for training. A good starting task is grasping an object at different locations and placing it in a bin. We suggest recording at least 50 episodes, with 10 episodes per location. Keep the cameras fixed and maintain consistent grasping behavior throughout the recordings. Also make sure the object you are manipulating is visible on the camera's. A good rule of thumb is you should be able to do the task yourself by only looking at the camera images. + +In the following sections, you’ll train your neural network. After achieving reliable grasping performance, you can start introducing more variations during data collection, such as additional grasp locations, different grasping techniques, and altering camera positions. + +Avoid adding too much variation too quickly, as it may hinder your results. + +If you want to dive deeper into this important topic, you can check out the [blog post](https://huggingface.co/blog/lerobot-datasets#what-makes-a-good-dataset) we wrote on what makes a good dataset. + +#### Troubleshooting: +- On Linux, if the left and right arrow keys and escape key don't have any effect during data recording, make sure you've set the `$DISPLAY` environment variable. See [pynput limitations](https://pynput.readthedocs.io/en/latest/limitations.html#linux). + + +## Replay an episode + +To replay an episode run the API example below, make sure to change `remote_ip`, `port`, LeRobotDatasetId and episode index. + + +```bash +python examples/lekiwi/replay.py +``` + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by the training part of this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +## Evaluate your policy + +To evaluate your policy run the `evaluate.py` API example, make sure to change `remote_ip`, `port`, model.. + +```bash +python examples/lekiwi/evaluate.py +``` + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/lerobot/common/robots/lekiwi/lekiwi.py b/lerobot/common/robots/lekiwi/lekiwi.py new file mode 100644 index 000000000..f6a9b8bf1 --- /dev/null +++ b/lerobot/common/robots/lekiwi/lekiwi.py @@ -0,0 +1,411 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from itertools import chain +from typing import Any + +import numpy as np + +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.feetech import ( + FeetechMotorsBus, + OperatingMode, +) + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_lekiwi import LeKiwiConfig + +logger = logging.getLogger(__name__) + + +class LeKiwi(Robot): + """ + The robot includes a three omniwheel mobile base and a remote follower arm. + The leader arm is connected locally (on the laptop) and its joint positions are recorded and then + forwarded to the remote follower arm (after applying a safety clamp). + In parallel, keyboard teleoperation is used to generate raw velocity commands for the wheels. + """ + + config_class = LeKiwiConfig + name = "lekiwi" + + def __init__(self, config: LeKiwiConfig): + super().__init__(config) + self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + # arm + "arm_shoulder_pan": Motor(1, "sts3215", norm_mode_body), + "arm_shoulder_lift": Motor(2, "sts3215", norm_mode_body), + "arm_elbow_flex": Motor(3, "sts3215", norm_mode_body), + "arm_wrist_flex": Motor(4, "sts3215", norm_mode_body), + "arm_wrist_roll": Motor(5, "sts3215", norm_mode_body), + "arm_gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + # base + "base_left_wheel": Motor(7, "sts3215", MotorNormMode.RANGE_M100_100), + "base_back_wheel": Motor(8, "sts3215", MotorNormMode.RANGE_M100_100), + "base_right_wheel": Motor(9, "sts3215", MotorNormMode.RANGE_M100_100), + }, + calibration=self.calibration, + ) + self.arm_motors = [motor for motor in self.bus.motors if motor.startswith("arm")] + self.base_motors = [motor for motor in self.bus.motors if motor.startswith("base")] + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _state_ft(self) -> dict[str, type]: + return dict.fromkeys( + ( + "arm_shoulder_pan.pos", + "arm_shoulder_lift.pos", + "arm_elbow_flex.pos", + "arm_wrist_flex.pos", + "arm_wrist_roll.pos", + "arm_gripper.pos", + "x.vel", + "y.vel", + "theta.vel", + ), + float, + ) + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._state_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._state_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + logger.info(f"\nRunning calibration of {self}") + + motors = self.arm_motors + self.base_motors + + self.bus.disable_torque(self.arm_motors) + for name in self.arm_motors: + self.bus.write("Operating_Mode", name, OperatingMode.POSITION.value) + + input("Move robot to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings(self.arm_motors) + + homing_offsets.update(dict.fromkeys(self.base_motors, 0)) + + full_turn_motor = [ + motor for motor in motors if any(keyword in motor for keyword in ["wheel", "wrist"]) + ] + unknown_range_motors = [motor for motor in motors if motor not in full_turn_motor] + + print( + f"Move all arm joints except '{full_turn_motor}' sequentially through their " + "entire ranges of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + for name in full_turn_motor: + range_mins[name] = 0 + range_maxes[name] = 4095 + + self.calibration = {} + for name, motor in self.bus.motors.items(): + self.calibration[name] = MotorCalibration( + id=motor.id, + drive_mode=0, + homing_offset=homing_offsets[name], + range_min=range_mins[name], + range_max=range_maxes[name], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) + + def configure(self): + # Set-up arm actuators (position mode) + # We assume that at connection time, arm is in a rest position, + # and torque can be safely disabled to run calibration. + self.bus.disable_torque() + self.bus.configure_motors() + for name in self.arm_motors: + self.bus.write("Operating_Mode", name, OperatingMode.POSITION.value) + # Set P_Coefficient to lower value to avoid shakiness (Default is 32) + self.bus.write("P_Coefficient", name, 16) + # Set I_Coefficient and D_Coefficient to default value 0 and 32 + self.bus.write("I_Coefficient", name, 0) + self.bus.write("D_Coefficient", name, 32) + + for name in self.base_motors: + self.bus.write("Operating_Mode", name, OperatingMode.VELOCITY.value) + + self.bus.enable_torque() + + def setup_motors(self) -> None: + for motor in chain(reversed(self.arm_motors), reversed(self.base_motors)): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + @staticmethod + def _degps_to_raw(degps: float) -> int: + steps_per_deg = 4096.0 / 360.0 + speed_in_steps = degps * steps_per_deg + speed_int = int(round(speed_in_steps)) + # Cap the value to fit within signed 16-bit range (-32768 to 32767) + if speed_int > 0x7FFF: + speed_int = 0x7FFF # 32767 -> maximum positive value + elif speed_int < -0x8000: + speed_int = -0x8000 # -32768 -> minimum negative value + return speed_int + + @staticmethod + def _raw_to_degps(raw_speed: int) -> float: + steps_per_deg = 4096.0 / 360.0 + magnitude = raw_speed + degps = magnitude / steps_per_deg + return degps + + def _body_to_wheel_raw( + self, + x: float, + y: float, + theta: float, + wheel_radius: float = 0.05, + base_radius: float = 0.125, + max_raw: int = 3000, + ) -> dict: + """ + Convert desired body-frame velocities into wheel raw commands. + + Parameters: + x_cmd : Linear velocity in x (m/s). + y_cmd : Linear velocity in y (m/s). + theta_cmd : Rotational velocity (deg/s). + wheel_radius: Radius of each wheel (meters). + base_radius : Distance from the center of rotation to each wheel (meters). + max_raw : Maximum allowed raw command (ticks) per wheel. + + Returns: + A dictionary with wheel raw commands: + {"base_left_wheel": value, "base_back_wheel": value, "base_right_wheel": value}. + + Notes: + - Internally, the method converts theta_cmd to rad/s for the kinematics. + - The raw command is computed from the wheels angular speed in deg/s + using _degps_to_raw(). If any command exceeds max_raw, all commands + are scaled down proportionally. + """ + # Convert rotational velocity from deg/s to rad/s. + theta_rad = theta * (np.pi / 180.0) + # Create the body velocity vector [x, y, theta_rad]. + velocity_vector = np.array([x, y, theta_rad]) + + # Define the wheel mounting angles with a -90° offset. + angles = np.radians(np.array([240, 0, 120]) - 90) + # Build the kinematic matrix: each row maps body velocities to a wheel’s linear speed. + # The third column (base_radius) accounts for the effect of rotation. + m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) + + # Compute each wheel’s linear speed (m/s) and then its angular speed (rad/s). + wheel_linear_speeds = m.dot(velocity_vector) + wheel_angular_speeds = wheel_linear_speeds / wheel_radius + + # Convert wheel angular speeds from rad/s to deg/s. + wheel_degps = wheel_angular_speeds * (180.0 / np.pi) + + # Scaling + steps_per_deg = 4096.0 / 360.0 + raw_floats = [abs(degps) * steps_per_deg for degps in wheel_degps] + max_raw_computed = max(raw_floats) + if max_raw_computed > max_raw: + scale = max_raw / max_raw_computed + wheel_degps = wheel_degps * scale + + # Convert each wheel’s angular speed (deg/s) to a raw integer. + wheel_raw = [self._degps_to_raw(deg) for deg in wheel_degps] + + return { + "base_left_wheel": wheel_raw[0], + "base_back_wheel": wheel_raw[1], + "base_right_wheel": wheel_raw[2], + } + + def _wheel_raw_to_body( + self, + left_wheel_speed, + back_wheel_speed, + right_wheel_speed, + wheel_radius: float = 0.05, + base_radius: float = 0.125, + ) -> dict[str, Any]: + """ + Convert wheel raw command feedback back into body-frame velocities. + + Parameters: + wheel_raw : Vector with raw wheel commands ("base_left_wheel", "base_back_wheel", "base_right_wheel"). + wheel_radius: Radius of each wheel (meters). + base_radius : Distance from the robot center to each wheel (meters). + + Returns: + A dict (x.vel, y.vel, theta.vel) all in m/s + """ + + # Convert each raw command back to an angular speed in deg/s. + wheel_degps = np.array( + [ + self._raw_to_degps(left_wheel_speed), + self._raw_to_degps(back_wheel_speed), + self._raw_to_degps(right_wheel_speed), + ] + ) + + # Convert from deg/s to rad/s. + wheel_radps = wheel_degps * (np.pi / 180.0) + # Compute each wheel’s linear speed (m/s) from its angular speed. + wheel_linear_speeds = wheel_radps * wheel_radius + + # Define the wheel mounting angles with a -90° offset. + angles = np.radians(np.array([240, 0, 120]) - 90) + m = np.array([[np.cos(a), np.sin(a), base_radius] for a in angles]) + + # Solve the inverse kinematics: body_velocity = M⁻¹ · wheel_linear_speeds. + m_inv = np.linalg.inv(m) + velocity_vector = m_inv.dot(wheel_linear_speeds) + x, y, theta_rad = velocity_vector + theta = theta_rad * (180.0 / np.pi) + return { + "x.vel": x, + "y.vel": y, + "theta.vel": theta, + } # m/s and deg/s + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read actuators position for arm and vel for base + start = time.perf_counter() + arm_pos = self.bus.sync_read("Present_Position", self.arm_motors) + base_wheel_vel = self.bus.sync_read("Present_Velocity", self.base_motors) + + base_vel = self._wheel_raw_to_body( + base_wheel_vel["base_left_wheel"], + base_wheel_vel["base_back_wheel"], + base_wheel_vel["base_right_wheel"], + ) + + arm_state = {f"{k}.pos": v for k, v in arm_pos.items()} + + obs_dict = {**arm_state, **base_vel} + + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """Command lekiwi to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Raises: + RobotDeviceNotConnectedError: if robot is not connected. + + Returns: + np.ndarray: the action sent to the motors, potentially clipped. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")} + base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")} + + base_wheel_goal_vel = self._body_to_wheel_raw( + base_goal_vel["x.vel"], base_goal_vel["y.vel"], base_goal_vel["theta.vel"] + ) + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position", self.arm_motors) + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in arm_goal_pos.items()} + arm_safe_goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + arm_goal_pos = arm_safe_goal_pos + + # Send goal position to the actuators + arm_goal_pos_raw = {k.replace(".pos", ""): v for k, v in arm_goal_pos.items()} + self.bus.sync_write("Goal_Position", arm_goal_pos_raw) + self.bus.sync_write("Goal_Velocity", base_wheel_goal_vel) + + return {**arm_goal_pos, **base_goal_vel} + + def stop_base(self): + self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5) + logger.info("Base motors stopped") + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.stop_base() + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/robots/lekiwi/lekiwi_client.py b/lerobot/common/robots/lekiwi/lekiwi_client.py new file mode 100644 index 000000000..f79b7f81a --- /dev/null +++ b/lerobot/common/robots/lekiwi/lekiwi_client.py @@ -0,0 +1,342 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(aliberts, Steven, Pepijn): use gRPC calls instead of zmq? + +import base64 +import json +import logging +from functools import cached_property +from typing import Any, Dict, Optional, Tuple + +import cv2 +import numpy as np +import torch +import zmq + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..robot import Robot +from .config_lekiwi import LeKiwiClientConfig + + +class LeKiwiClient(Robot): + config_class = LeKiwiClientConfig + name = "lekiwi_client" + + def __init__(self, config: LeKiwiClientConfig): + super().__init__(config) + self.config = config + self.id = config.id + self.robot_type = config.type + + self.remote_ip = config.remote_ip + self.port_zmq_cmd = config.port_zmq_cmd + self.port_zmq_observations = config.port_zmq_observations + + self.teleop_keys = config.teleop_keys + + self.polling_timeout_ms = config.polling_timeout_ms + self.connect_timeout_s = config.connect_timeout_s + + self.zmq_context = None + self.zmq_cmd_socket = None + self.zmq_observation_socket = None + + self.last_frames = {} + + self.last_remote_state = {} + + # Define three speed levels and a current index + self.speed_levels = [ + {"xy": 0.1, "theta": 30}, # slow + {"xy": 0.2, "theta": 60}, # medium + {"xy": 0.3, "theta": 90}, # fast + ] + self.speed_index = 0 # Start at slow + + self._is_connected = False + self.logs = {} + + @cached_property + def _state_ft(self) -> dict[str, type]: + return dict.fromkeys( + ( + "arm_shoulder_pan.pos", + "arm_shoulder_lift.pos", + "arm_elbow_flex.pos", + "arm_wrist_flex.pos", + "arm_wrist_roll.pos", + "arm_gripper.pos", + "x.vel", + "y.vel", + "theta.vel", + ), + float, + ) + + @cached_property + def _state_order(self) -> tuple[str, ...]: + return tuple(self._state_ft.keys()) + + @cached_property + def _cameras_ft(self) -> dict[str, tuple[int, int, int]]: + return {name: (cfg.height, cfg.width, 3) for name, cfg in self.config.cameras.items()} + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._state_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._state_ft + + @property + def is_connected(self) -> bool: + return self._is_connected + + @property + def is_calibrated(self) -> bool: + pass + + def connect(self) -> None: + """Establishes ZMQ sockets with the remote mobile robot""" + + if self._is_connected: + raise DeviceAlreadyConnectedError( + "LeKiwi Daemon is already connected. Do not run `robot.connect()` twice." + ) + + self.zmq_context = zmq.Context() + self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH) + zmq_cmd_locator = f"tcp://{self.remote_ip}:{self.port_zmq_cmd}" + self.zmq_cmd_socket.connect(zmq_cmd_locator) + self.zmq_cmd_socket.setsockopt(zmq.CONFLATE, 1) + + self.zmq_observation_socket = self.zmq_context.socket(zmq.PULL) + zmq_observations_locator = f"tcp://{self.remote_ip}:{self.port_zmq_observations}" + self.zmq_observation_socket.connect(zmq_observations_locator) + self.zmq_observation_socket.setsockopt(zmq.CONFLATE, 1) + + poller = zmq.Poller() + poller.register(self.zmq_observation_socket, zmq.POLLIN) + socks = dict(poller.poll(self.connect_timeout_s * 1000)) + if self.zmq_observation_socket not in socks or socks[self.zmq_observation_socket] != zmq.POLLIN: + raise DeviceNotConnectedError("Timeout waiting for LeKiwi Host to connect expired.") + + self._is_connected = True + + def calibrate(self) -> None: + pass + + def _poll_and_get_latest_message(self) -> Optional[str]: + """Polls the ZMQ socket for a limited time and returns the latest message string.""" + poller = zmq.Poller() + poller.register(self.zmq_observation_socket, zmq.POLLIN) + + try: + socks = dict(poller.poll(self.polling_timeout_ms)) + except zmq.ZMQError as e: + logging.error(f"ZMQ polling error: {e}") + return None + + if self.zmq_observation_socket not in socks: + logging.info("No new data available within timeout.") + return None + + last_msg = None + while True: + try: + msg = self.zmq_observation_socket.recv_string(zmq.NOBLOCK) + last_msg = msg + except zmq.Again: + break + + if last_msg is None: + logging.warning("Poller indicated data, but failed to retrieve message.") + + return last_msg + + def _parse_observation_json(self, obs_string: str) -> Optional[Dict[str, Any]]: + """Parses the JSON observation string.""" + try: + return json.loads(obs_string) + except json.JSONDecodeError as e: + logging.error(f"Error decoding JSON observation: {e}") + return None + + def _decode_image_from_b64(self, image_b64: str) -> Optional[np.ndarray]: + """Decodes a base64 encoded image string to an OpenCV image.""" + if not image_b64: + return None + try: + jpg_data = base64.b64decode(image_b64) + np_arr = np.frombuffer(jpg_data, dtype=np.uint8) + frame = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) + if frame is None: + logging.warning("cv2.imdecode returned None for an image.") + return frame + except (TypeError, ValueError) as e: + logging.error(f"Error decoding base64 image data: {e}") + return None + + def _remote_state_from_obs( + self, observation: Dict[str, Any] + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """Extracts frames, and state from the parsed observation.""" + flat_state = {key: value for key, value in observation.items() if key in self._state_ft} + + state_vec = np.array( + [flat_state.get(k, 0.0) for k in self._state_order], + dtype=np.float32, + ) + + # Decode images + image_observation = { + f"observation.images.{key}": value + for key, value in observation.items() + if key in self._cameras_ft + } + current_frames: Dict[str, np.ndarray] = {} + for cam_name, image_b64 in image_observation.items(): + frame = self._decode_image_from_b64(image_b64) + if frame is not None: + current_frames[cam_name] = frame + + return current_frames, {"observation.state": state_vec} + + def _get_data(self) -> Tuple[Dict[str, np.ndarray], Dict[str, Any], Dict[str, Any]]: + """ + Polls the video socket for the latest observation data. + + Attempts to retrieve and decode the latest message within a short timeout. + If successful, updates and returns the new frames, speed, and arm state. + If no new data arrives or decoding fails, returns the last known values. + """ + + # 1. Get the latest message string from the socket + latest_message_str = self._poll_and_get_latest_message() + + # 2. If no message, return cached data + if latest_message_str is None: + return self.last_frames, self.last_remote_state + + # 3. Parse the JSON message + observation = self._parse_observation_json(latest_message_str) + + # 4. If JSON parsing failed, return cached data + if observation is None: + return self.last_frames, self.last_remote_state + + # 5. Process the valid observation data + try: + new_frames, new_state = self._remote_state_from_obs(observation) + except Exception as e: + logging.error(f"Error processing observation data, serving last observation: {e}") + return self.last_frames, self.last_remote_state + + self.last_frames = new_frames + self.last_remote_state = new_state + + return new_frames, new_state + + def get_observation(self) -> dict[str, Any]: + """ + Capture observations from the remote robot: current follower arm positions, + present wheel speeds (converted to body-frame velocities: x, y, theta), + and a camera frame. Receives over ZMQ, translate to body-frame vel + """ + if not self._is_connected: + raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.") + + frames, obs_dict = self._get_data() + + # Loop over each configured camera + for cam_name, frame in frames.items(): + if frame is None: + logging.warning("Frame is None") + frame = np.zeros((640, 480, 3), dtype=np.uint8) + obs_dict[cam_name] = torch.from_numpy(frame) + + return obs_dict + + def _from_keyboard_to_base_action(self, pressed_keys: np.ndarray): + # Speed control + if self.teleop_keys["speed_up"] in pressed_keys: + self.speed_index = min(self.speed_index + 1, 2) + if self.teleop_keys["speed_down"] in pressed_keys: + self.speed_index = max(self.speed_index - 1, 0) + speed_setting = self.speed_levels[self.speed_index] + xy_speed = speed_setting["xy"] # e.g. 0.1, 0.25, or 0.4 + theta_speed = speed_setting["theta"] # e.g. 30, 60, or 90 + + x_cmd = 0.0 # m/s forward/backward + y_cmd = 0.0 # m/s lateral + theta_cmd = 0.0 # deg/s rotation + + if self.teleop_keys["forward"] in pressed_keys: + x_cmd += xy_speed + if self.teleop_keys["backward"] in pressed_keys: + x_cmd -= xy_speed + if self.teleop_keys["left"] in pressed_keys: + y_cmd += xy_speed + if self.teleop_keys["right"] in pressed_keys: + y_cmd -= xy_speed + if self.teleop_keys["rotate_left"] in pressed_keys: + theta_cmd += theta_speed + if self.teleop_keys["rotate_right"] in pressed_keys: + theta_cmd -= theta_speed + return { + "x.vel": x_cmd, + "y.vel": y_cmd, + "theta.vel": theta_cmd, + } + + def configure(self): + pass + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ + + Args: + action (np.ndarray): array containing the goal positions for the motors. + + Raises: + RobotDeviceNotConnectedError: if robot is not connected. + + Returns: + np.ndarray: the action sent to the motors, potentially clipped. + """ + if not self._is_connected: + raise DeviceNotConnectedError( + "ManipulatorRobot is not connected. You need to run `robot.connect()`." + ) + + self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space + + # TODO(Steven): Remove the np conversion when it is possible to record a non-numpy array value + actions = np.array([action.get(k, 0.0) for k in self._state_order], dtype=np.float32) + return {"action": actions} + + def disconnect(self): + """Cleans ZMQ comms""" + + if not self._is_connected: + raise DeviceNotConnectedError( + "LeKiwi is not connected. You need to run `robot.connect()` before disconnecting." + ) + self.zmq_observation_socket.close() + self.zmq_cmd_socket.close() + self.zmq_context.term() + self._is_connected = False diff --git a/lerobot/common/robots/lekiwi/lekiwi_host.py b/lerobot/common/robots/lekiwi/lekiwi_host.py new file mode 100644 index 000000000..1155cf71c --- /dev/null +++ b/lerobot/common/robots/lekiwi/lekiwi_host.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import logging +import time + +import cv2 +import zmq + +from .config_lekiwi import LeKiwiConfig, LeKiwiHostConfig +from .lekiwi import LeKiwi + + +class LeKiwiHost: + def __init__(self, config: LeKiwiHostConfig): + self.zmq_context = zmq.Context() + self.zmq_cmd_socket = self.zmq_context.socket(zmq.PULL) + self.zmq_cmd_socket.setsockopt(zmq.CONFLATE, 1) + self.zmq_cmd_socket.bind(f"tcp://*:{config.port_zmq_cmd}") + + self.zmq_observation_socket = self.zmq_context.socket(zmq.PUSH) + self.zmq_observation_socket.setsockopt(zmq.CONFLATE, 1) + self.zmq_observation_socket.bind(f"tcp://*:{config.port_zmq_observations}") + + self.connection_time_s = config.connection_time_s + self.watchdog_timeout_ms = config.watchdog_timeout_ms + self.max_loop_freq_hz = config.max_loop_freq_hz + + def disconnect(self): + self.zmq_observation_socket.close() + self.zmq_cmd_socket.close() + self.zmq_context.term() + + +def main(): + logging.info("Configuring LeKiwi") + robot_config = LeKiwiConfig() + robot = LeKiwi(robot_config) + + logging.info("Connecting LeKiwi") + robot.connect() + + logging.info("Starting HostAgent") + host_config = LeKiwiHostConfig() + host = LeKiwiHost(host_config) + + last_cmd_time = time.time() + watchdog_active = False + logging.info("Waiting for commands...") + try: + # Business logic + start = time.perf_counter() + duration = 0 + while duration < host.connection_time_s: + loop_start_time = time.time() + try: + msg = host.zmq_cmd_socket.recv_string(zmq.NOBLOCK) + data = dict(json.loads(msg)) + _action_sent = robot.send_action(data) + last_cmd_time = time.time() + watchdog_active = False + except zmq.Again: + if not watchdog_active: + logging.warning("No command available") + except Exception as e: + logging.error("Message fetching failed: %s", e) + + now = time.time() + if (now - last_cmd_time > host.watchdog_timeout_ms / 1000) and not watchdog_active: + logging.warning( + f"Command not received for more than {host.watchdog_timeout_ms} milliseconds. Stopping the base." + ) + watchdog_active = True + robot.stop_base() + + last_observation = robot.get_observation() + + # Encode ndarrays to base64 strings + for cam_key, _ in robot.cameras.items(): + ret, buffer = cv2.imencode( + ".jpg", last_observation[cam_key], [int(cv2.IMWRITE_JPEG_QUALITY), 90] + ) + if ret: + last_observation[cam_key] = base64.b64encode(buffer).decode("utf-8") + else: + last_observation[cam_key] = "" + + # Send the observation to the remote agent + try: + host.zmq_observation_socket.send_string(json.dumps(last_observation), flags=zmq.NOBLOCK) + except zmq.Again: + logging.info("Dropping observation, no client connected") + + # Ensure a short sleep to avoid overloading the CPU. + elapsed = time.time() - loop_start_time + + time.sleep(max(1 / host.max_loop_freq_hz - elapsed, 0)) + duration = time.perf_counter() - start + print("Cycle time reached.") + + except KeyboardInterrupt: + print("Keyboard interrupt received. Exiting...") + finally: + print("Shutting down Lekiwi Host.") + robot.disconnect() + host.disconnect() + + logging.info("Finished LeKiwi cleanly") + + +if __name__ == "__main__": + main() diff --git a/lerobot/common/robots/robot.py b/lerobot/common/robots/robot.py new file mode 100644 index 000000000..76c57faf4 --- /dev/null +++ b/lerobot/common/robots/robot.py @@ -0,0 +1,184 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from pathlib import Path +from typing import Any, Type + +import draccus + +from lerobot.common.constants import HF_LEROBOT_CALIBRATION, ROBOTS +from lerobot.common.motors import MotorCalibration + +from .config import RobotConfig + + +# TODO(aliberts): action/obs typing such as Generic[ObsType, ActType] similar to gym.Env ? +# https://github.com/Farama-Foundation/Gymnasium/blob/3287c869f9a48d99454306b0d4b4ec537f0f35e3/gymnasium/core.py#L23 +class Robot(abc.ABC): + """ + The base abstract class for all LeRobot-compatible robots. + + This class provides a standardized interface for interacting with physical robots. + Subclasses must implement all abstract methods and properties to be usable. + + Attributes: + config_class (RobotConfig): The expected configuration class for this robot. + name (str): The unique robot name used to identify this robot type. + """ + + # Set these in ALL subclasses + config_class: Type[RobotConfig] + name: str + + def __init__(self, config: RobotConfig): + self.robot_type = self.name + self.id = config.id + self.calibration_dir = ( + config.calibration_dir if config.calibration_dir else HF_LEROBOT_CALIBRATION / ROBOTS / self.name + ) + self.calibration_dir.mkdir(parents=True, exist_ok=True) + self.calibration_fpath = self.calibration_dir / f"{self.id}.json" + self.calibration: dict[str, MotorCalibration] = {} + if self.calibration_fpath.is_file(): + self._load_calibration() + + def __str__(self) -> str: + return f"{self.id} {self.__class__.__name__}" + + # TODO(aliberts): create a proper Feature class for this that links with datasets + @property + @abc.abstractmethod + def observation_features(self) -> dict: + """ + A dictionary describing the structure and types of the observations produced by the robot. + Its structure (keys) should match the structure of what is returned by :pymeth:`get_observation`. + Values for the dict should either be: + - The type of the value if it's a simple value, e.g. `float` for single proprioceptive value (a joint's position/velocity) + - A tuple representing the shape if it's an array-type value, e.g. `(height, width, channel)` for images + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ + pass + + @property + @abc.abstractmethod + def action_features(self) -> dict: + """ + A dictionary describing the structure and types of the actions expected by the robot. Its structure + (keys) should match the structure of what is passed to :pymeth:`send_action`. Values for the dict + should be the type of the value if it's a simple value, e.g. `float` for single proprioceptive value + (a joint's goal position/velocity) + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ + pass + + @property + @abc.abstractmethod + def is_connected(self) -> bool: + """ + Whether the robot is currently connected or not. If `False`, calling :pymeth:`get_observation` or + :pymeth:`send_action` should raise an error. + """ + pass + + @abc.abstractmethod + def connect(self, calibrate: bool = True) -> None: + """ + Establish communication with the robot. + + Args: + calibrate (bool): If True, automatically calibrate the robot after connecting if it's not + calibrated or needs calibration (this is hardware-dependant). + """ + pass + + @property + @abc.abstractmethod + def is_calibrated(self) -> bool: + """Whether the robot is currently calibrated or not. Should be always `True` if not applicable""" + pass + + @abc.abstractmethod + def calibrate(self) -> None: + """ + Calibrate the robot if applicable. If not, this should be a no-op. + + This method should collect any necessary data (e.g., motor offsets) and update the + :pyattr:`calibration` dictionary accordingly. + """ + pass + + def _load_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to load calibration data from the specified file. + + Args: + fpath (Path | None): Optional path to the calibration file. Defaults to `self.calibration_fpath`. + """ + fpath = self.calibration_fpath if fpath is None else fpath + with open(fpath) as f, draccus.config_type("json"): + self.calibration = draccus.load(dict[str, MotorCalibration], f) + + def _save_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to save calibration data to the specified file. + + Args: + fpath (Path | None): Optional path to save the calibration file. Defaults to `self.calibration_fpath`. + """ + fpath = self.calibration_fpath if fpath is None else fpath + with open(fpath, "w") as f, draccus.config_type("json"): + draccus.dump(self.calibration, f, indent=4) + + @abc.abstractmethod + def configure(self) -> None: + """ + Apply any one-time or runtime configuration to the robot. + This may include setting motor parameters, control modes, or initial state. + """ + pass + + @abc.abstractmethod + def get_observation(self) -> dict[str, Any]: + """ + Retrieve the current observation from the robot. + + Returns: + dict[str, Any]: A flat dictionary representing the robot's current sensory state. Its structure + should match :pymeth:`observation_features`. + """ + + pass + + @abc.abstractmethod + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """ + Send an action command to the robot. + + Args: + action (dict[str, Any]): Dictionary representing the desired action. Its structure should match + :pymeth:`action_features`. + + Returns: + dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by + safety limits on velocity. + """ + pass + + @abc.abstractmethod + def disconnect(self) -> None: + """Disconnect from the robot and perform any necessary cleanup.""" + pass diff --git a/lerobot/common/robots/so100_follower/__init__.py b/lerobot/common/robots/so100_follower/__init__.py new file mode 100644 index 000000000..63c3e1c17 --- /dev/null +++ b/lerobot/common/robots/so100_follower/__init__.py @@ -0,0 +1,3 @@ +from .config_so100_follower import SO100FollowerConfig, SO100FollowerEndEffectorConfig +from .so100_follower import SO100Follower +from .so100_follower_end_effector import SO100FollowerEndEffector diff --git a/lerobot/common/robots/so100_follower/config_so100_follower.py b/lerobot/common/robots/so100_follower/config_so100_follower.py new file mode 100644 index 000000000..b76675d26 --- /dev/null +++ b/lerobot/common/robots/so100_follower/config_so100_follower.py @@ -0,0 +1,63 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.cameras import CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("so100_follower") +@dataclass +class SO100FollowerConfig(RobotConfig): + # Port to connect to the arm + port: str + + disable_torque_on_disconnect: bool = True + + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + # cameras + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Set to `True` for backward compatibility with previous policies/dataset + use_degrees: bool = False + + +@RobotConfig.register_subclass("so100_follower_end_effector") +@dataclass +class SO100FollowerEndEffectorConfig(SO100FollowerConfig): + """Configuration for the SO100FollowerEndEffector robot.""" + + # Default bounds for the end-effector position (in meters) + end_effector_bounds: dict[str, list[float]] = field( + default_factory=lambda: { + "min": [-1.0, -1.0, -1.0], # min x, y, z + "max": [1.0, 1.0, 1.0], # max x, y, z + } + ) + + max_gripper_pos: float = 50 + + end_effector_step_sizes: dict[str, float] = field( + default_factory=lambda: { + "x": 0.02, + "y": 0.02, + "z": 0.02, + } + ) diff --git a/lerobot/common/robots/so100_follower/so100.mdx b/lerobot/common/robots/so100_follower/so100.mdx new file mode 100644 index 000000000..5443a687b --- /dev/null +++ b/lerobot/common/robots/so100_follower/so100.mdx @@ -0,0 +1,489 @@ +# SO-100 + +In the steps below, we explain how to assemble the SO-100 robot. + +## Source the parts + +Follow this [README](https://github.com/TheRobotStudio/SO-ARM100/blob/main/SO100.md). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. And advise if it's your first time printing or if you don't own a 3D printer. + +## Install LeRobot 🤗 + +To install LeRobot, follow our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Feetech SDK: +```bash +pip install -e ".[feetech]" +``` + +## Configure the motors + +**Note:** +Unlike the SO-101, the motor connectors are not easily accessible once the arm is assembled, so the configuration step must be done beforehand. + +### 1. Find the USB ports associated with each arm + +To find the port for each bus servo adapter, run this script: +```bash +python lerobot/find_port.py +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, you might need to give access to the USB ports by running: +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0', '/dev/ttyACM1'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM1 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. + + + + +### 2. Set the motors ids and baudrates + +Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. + +To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. + +If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match. + +#### Follower + +Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. + +For a visual reference on how to set the motor ids please refer to [this video](https://huggingface.co/docs/lerobot/en/so101#setup-motors-video) where we follow the process for the SO101 arm. + + + + +```bash +python -m lerobot.setup_motors \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step +``` + + + +```python +from lerobot.common.robots.so100_follower import SO100Follower, SO100FollowerConfig + +config = SO100FollowerConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_follower_arm", +) +follower = SO100Follower(config) +follower.setup_motors() +``` + + + +You should see the following instruction +``` +Connect the controller board to the 'gripper' motor only and press enter. +``` + +As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. + +
+Troubleshooting + + If you get an error at that point, check your cables and make sure they are plugged in properly: +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ +If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). +
+ +You should then see the following message: +``` +'gripper' motor id set to 6 +``` + +Followed by the next instruction: +``` +Connect the controller board to the 'wrist_roll' motor only and press enter. +``` + +You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. + +Repeat the operation for each motor as instructed. + +> [!TIP] +> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. + +When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. + +#### Leader +Do the same steps for the leader arm. + + + +```bash +python -m lerobot.setup_motors \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step +``` + + + +```python +from lerobot.common.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig + +config = SO100LeaderConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_leader_arm", +) +leader = SO100Leader(config) +leader.setup_motors() +``` + + + +## Step-by-Step Assembly Instructions + +## Remove the gears of the 6 leader motors + +
+Video removing gears + +
+ +
+ +
+ +Follow the video for removing gears. You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. + +### Clean Parts +Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. + +### Additional Guidance + +
+Video assembling arms + +
+ +
+ +
+ +**Note:** +This video provides visual guidance for assembling the arms, but it doesn't specify when or how to do the wiring. Inserting the cables beforehand is much easier than doing it afterward. The first arm may take a bit more than 1 hour to assemble, but once you get used to it, you can assemble the second arm in under 1 hour. + +--- + +### First Motor + +**Step 2: Insert Wires** +- Insert two wires into the first motor. + + + +**Step 3: Install in Base** +- Place the first motor into the base. + + + +**Step 4: Secure Motor** +- Fasten the motor with 4 screws. Two from the bottom and two from top. + +**Step 5: Attach Motor Holder** +- Slide over the first motor holder and fasten it using two screws (one on each side). + + + +**Step 6: Attach Motor Horns** +- Install both motor horns, securing the top horn with a screw. Try not to move the motor position when attaching the motor horn, especially for the leader arms, where we removed the gears. + + + +
+ Video adding motor horn + +
+ +**Step 7: Attach Shoulder Part** +- Route one wire to the back of the robot and the other to the left or towards you (see photo). +- Attach the shoulder part. + + + +**Step 8: Secure Shoulder** +- Tighten the shoulder part with 4 screws on top and 4 on the bottom +*(access bottom holes by turning the shoulder).* + +--- + +### Second Motor Assembly + +**Step 9: Install Motor 2** +- Slide the second motor in from the top and link the wire from motor 1 to motor 2. + + + +**Step 10: Attach Shoulder Holder** +- Add the shoulder motor holder. +- Ensure the wire from motor 1 to motor 2 goes behind the holder while the other wire is routed upward (see photo). +- This part can be tight to assemble, you can use a workbench like the image or a similar setup to push the part around the motor. + +
+ + + +
+ +**Step 11: Secure Motor 2** +- Fasten the second motor with 4 screws. + +**Step 12: Attach Motor Horn** +- Attach both motor horns to motor 2, again use the horn screw. + +**Step 13: Attach Base** +- Install the base attachment using 2 screws. + + + +**Step 14: Attach Upper Arm** +- Attach the upper arm with 4 screws on each side. + + + +--- + +### Third Motor Assembly + +**Step 15: Install Motor 3** +- Route the motor cable from motor 2 through the cable holder to motor 3, then secure motor 3 with 4 screws. + +**Step 16: Attach Motor Horn** +- Attach both motor horns to motor 3 and secure one again with a horn screw. + + + +**Step 17: Attach Forearm** +- Connect the forearm to motor 3 using 4 screws on each side. + + + +--- + +### Fourth Motor Assembly + +**Step 18: Install Motor 4** +- Slide in motor 4, attach the cable from motor 3, and secure the cable in its holder with a screw. + +
+ + +
+ +**Step 19: Attach Motor Holder 4** +- Install the fourth motor holder (a tight fit). Ensure one wire is routed upward and the wire from motor 3 is routed downward (see photo). + + + +**Step 20: Secure Motor 4 & Attach Horn** +- Fasten motor 4 with 4 screws and attach its motor horns, use for one a horn screw. + + + +--- + +### Wrist Assembly + +**Step 21: Install Motor 5** +- Insert motor 5 into the wrist holder and secure it with 2 front screws. + + + +**Step 22: Attach Wrist** +- Connect the wire from motor 4 to motor 5. And already insert the other wire for the gripper. +- Secure the wrist to motor 4 using 4 screws on both sides. + + + +**Step 23: Attach Wrist Horn** +- Install only one motor horn on the wrist motor and secure it with a horn screw. + + + +--- + +### Follower Configuration + +**Step 24: Attach Gripper** +- Attach the gripper to motor 5. + + + +**Step 25: Install Gripper Motor** +- Insert the gripper motor, connect the motor wire from motor 5 to motor 6, and secure it with 3 screws on each side. + + + +**Step 26: Attach Gripper Horn & Claw** +- Attach the motor horns and again use a horn screw. +- Install the gripper claw and secure it with 4 screws on both sides. + + + +**Step 27: Mount Controller** +- Attach the motor controller to the back of the robot. + +
+ + +
+ +*Assembly complete – proceed to Leader arm assembly.* + +--- + +### Leader Configuration + +For the leader configuration, perform **Steps 1–23**. Make sure that you removed the motor gears from the motors. + +**Step 24: Attach Leader Holder** +- Mount the leader holder onto the wrist and secure it with a screw. + + + +**Step 25: Attach Handle** +- Attach the handle to motor 5 using 4 screws. + + + +**Step 26: Install Gripper Motor** +- Insert the gripper motor, secure it with 3 screws on each side, attach a motor horn using a horn screw, and connect the motor wire. + + + +**Step 27: Attach Trigger** +- Attach the follower trigger with 4 screws. + + + +**Step 28: Mount Controller** +- Attach the motor controller to the back of the robot. + +
+ + +
+ +## Calibrate + +Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +#### Follower + +Run the following command or API example to calibrate the follower arm: + + + + +```bash +python -m lerobot.calibrate \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --robot.id=my_awesome_follower_arm # <- Give the robot a unique name +``` + + + +```python +from lerobot.common.robots.so100_follower import SO100FollowerConfig, SO100Follower + +config = SO100FollowerConfig( + port="/dev/tty.usbmodem585A0076891", + id="my_awesome_follower_arm", +) + +follower = SO100Follower(config) +follower.connect(calibrate=False) +follower.calibrate() +follower.disconnect() +``` + + + +We unified the calibration method for most robots. Thus, the calibration steps for this SO100 arm are the same as the steps for the Koch and SO101. First, we have to move the robot to the position where each joint is in the middle of its range, then we press `Enter`. Secondly, we move all joints through their full range of motion. A video of this same process for the SO101 as reference can be found [here](https://huggingface.co/docs/lerobot/en/so101#calibration-video) + +#### Leader + +Do the same steps to calibrate the leader arm, run the following command or API example: + + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + +```python +from lerobot.common.teleoperators.so100_leader import SO100LeaderConfig, SO100Leader + +config = SO100LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_awesome_leader_arm", +) + +leader = SO100Leader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/lerobot/common/robots/so100_follower/so100_follower.py b/lerobot/common/robots/so100_follower/so100_follower.py new file mode 100644 index 000000000..952049940 --- /dev/null +++ b/lerobot/common/robots/so100_follower/so100_follower.py @@ -0,0 +1,214 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.feetech import ( + FeetechMotorsBus, + OperatingMode, +) + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_so100_follower import SO100FollowerConfig + +logger = logging.getLogger(__name__) + + +class SO100Follower(Robot): + """ + [SO-100 Follower Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + """ + + config_class = SO100FollowerConfig + name = "so100_follower" + + def __init__(self, config: SO100FollowerConfig): + super().__init__(config) + self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "sts3215", norm_mode_body), + "shoulder_lift": Motor(2, "sts3215", norm_mode_body), + "elbow_flex": Motor(3, "sts3215", norm_mode_body), + "wrist_flex": Motor(4, "sts3215", norm_mode_body), + "wrist_roll": Motor(5, "sts3215", norm_mode_body), + "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + We assume that at connection time, arm is in a rest position, + and torque can be safely disabled to run calibration. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + full_turn_motor = "wrist_roll" + unknown_range_motors = [motor for motor in self.bus.motors if motor != full_turn_motor] + print( + f"Move all joints except '{full_turn_motor}' sequentially through their " + "entire ranges of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + range_mins[full_turn_motor] = 0 + range_maxes[full_turn_motor] = 4095 + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) + + def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + # Set P_Coefficient to lower value to avoid shakiness (Default is 32) + self.bus.write("P_Coefficient", motor, 16) + # Set I_Coefficient and D_Coefficient to default value 0 and 32 + self.bus.write("I_Coefficient", motor, 0) + self.bus.write("D_Coefficient", motor, 32) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """Command arm to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Raises: + RobotDeviceNotConnectedError: if robot is not connected. + + Returns: + the action sent to the motors, potentially clipped. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + # Send goal position to the arm + self.bus.sync_write("Goal_Position", goal_pos) + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/robots/so100_follower/so100_follower_end_effector.py b/lerobot/common/robots/so100_follower/so100_follower_end_effector.py new file mode 100644 index 000000000..82e89305b --- /dev/null +++ b/lerobot/common/robots/so100_follower/so100_follower_end_effector.py @@ -0,0 +1,193 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import Any + +import numpy as np + +from lerobot.common.cameras import make_cameras_from_configs +from lerobot.common.errors import DeviceNotConnectedError +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.motors import Motor, MotorNormMode +from lerobot.common.motors.feetech import FeetechMotorsBus + +from . import SO100Follower +from .config_so100_follower import SO100FollowerEndEffectorConfig + +logger = logging.getLogger(__name__) +EE_FRAME = "gripper_tip" + + +class SO100FollowerEndEffector(SO100Follower): + """ + SO100Follower robot with end-effector space control. + + This robot inherits from SO100Follower but transforms actions from + end-effector space to joint space before sending them to the motors. + """ + + config_class = SO100FollowerEndEffectorConfig + name = "so100_follower_end_effector" + + def __init__(self, config: SO100FollowerEndEffectorConfig): + super().__init__(config) + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "sts3215", MotorNormMode.DEGREES), + "shoulder_lift": Motor(2, "sts3215", MotorNormMode.DEGREES), + "elbow_flex": Motor(3, "sts3215", MotorNormMode.DEGREES), + "wrist_flex": Motor(4, "sts3215", MotorNormMode.DEGREES), + "wrist_roll": Motor(5, "sts3215", MotorNormMode.DEGREES), + "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + + self.cameras = make_cameras_from_configs(config.cameras) + + self.config = config + + # Initialize the kinematics module for the so100 robot + self.kinematics = RobotKinematics(robot_type="so_new_calibration") + + # Store the bounds for end-effector position + self.end_effector_bounds = self.config.end_effector_bounds + + self.current_ee_pos = None + self.current_joint_pos = None + + @property + def action_features(self) -> dict[str, Any]: + """ + Define action features for end-effector control. + Returns dictionary with dtype, shape, and names. + """ + return { + "dtype": "float32", + "shape": (4,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, + } + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """ + Transform action from end-effector space to joint space and send to motors. + + Args: + action: Dictionary with keys 'delta_x', 'delta_y', 'delta_z' for end-effector control + or a numpy array with [delta_x, delta_y, delta_z] + + Returns: + The joint-space action that was sent to the motors + """ + + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Convert action to numpy array if not already + if isinstance(action, dict): + if all(k in action for k in ["delta_x", "delta_y", "delta_z"]): + delta_ee = np.array( + [ + action["delta_x"] * self.config.end_effector_step_sizes["x"], + action["delta_y"] * self.config.end_effector_step_sizes["y"], + action["delta_z"] * self.config.end_effector_step_sizes["z"], + ], + dtype=np.float32, + ) + if "gripper" not in action: + action["gripper"] = [1.0] + action = np.append(delta_ee, action["gripper"]) + else: + logger.warning( + f"Expected action keys 'delta_x', 'delta_y', 'delta_z', got {list(action.keys())}" + ) + action = np.zeros(4, dtype=np.float32) + + if self.current_joint_pos is None: + # Read current joint positions + current_joint_pos = self.bus.sync_read("Present_Position") + self.current_joint_pos = np.array([current_joint_pos[name] for name in self.bus.motors]) + + # Calculate current end-effector position using forward kinematics + if self.current_ee_pos is None: + self.current_ee_pos = self.kinematics.forward_kinematics(self.current_joint_pos, frame=EE_FRAME) + + # Set desired end-effector position by adding delta + desired_ee_pos = np.eye(4) + desired_ee_pos[:3, :3] = self.current_ee_pos[:3, :3] # Keep orientation + + # Add delta to position and clip to bounds + desired_ee_pos[:3, 3] = self.current_ee_pos[:3, 3] + action[:3] + if self.end_effector_bounds is not None: + desired_ee_pos[:3, 3] = np.clip( + desired_ee_pos[:3, 3], + self.end_effector_bounds["min"], + self.end_effector_bounds["max"], + ) + + # Compute inverse kinematics to get joint positions + target_joint_values_in_degrees = self.kinematics.ik( + self.current_joint_pos, desired_ee_pos, position_only=True, frame=EE_FRAME + ) + + target_joint_values_in_degrees = np.clip(target_joint_values_in_degrees, -180.0, 180.0) + # Create joint space action dictionary + joint_action = { + f"{key}.pos": target_joint_values_in_degrees[i] for i, key in enumerate(self.bus.motors.keys()) + } + + # Handle gripper separately if included in action + # Gripper delta action is in the range 0 - 2, + # We need to shift the action to the range -1, 1 so that we can expand it to -Max_gripper_pos, Max_gripper_pos + joint_action["gripper.pos"] = np.clip( + self.current_joint_pos[-1] + (action[-1] - 1) * self.config.max_gripper_pos, + 5, + self.config.max_gripper_pos, + ) + + self.current_ee_pos = desired_ee_pos.copy() + self.current_joint_pos = target_joint_values_in_degrees.copy() + self.current_joint_pos[-1] = joint_action["gripper.pos"] + + # Send joint space action to parent class + return super().send_action(joint_action) + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def reset(self): + self.current_ee_pos = None + self.current_joint_pos = None diff --git a/lerobot/common/robots/so101_follower/__init__.py b/lerobot/common/robots/so101_follower/__init__.py new file mode 100644 index 000000000..f6615b15b --- /dev/null +++ b/lerobot/common/robots/so101_follower/__init__.py @@ -0,0 +1,2 @@ +from .config_so101_follower import SO101FollowerConfig +from .so101_follower import SO101Follower diff --git a/lerobot/common/robots/so101_follower/config_so101_follower.py b/lerobot/common/robots/so101_follower/config_so101_follower.py new file mode 100644 index 000000000..6dbf21fd5 --- /dev/null +++ b/lerobot/common/robots/so101_follower/config_so101_follower.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.cameras import CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("so101_follower") +@dataclass +class SO101FollowerConfig(RobotConfig): + # Port to connect to the arm + port: str + + disable_torque_on_disconnect: bool = True + + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + # cameras + cameras: dict[str, CameraConfig] = field(default_factory=dict) + + # Set to `True` for backward compatibility with previous policies/dataset + use_degrees: bool = False diff --git a/lerobot/common/robots/so101_follower/so101.mdx b/lerobot/common/robots/so101_follower/so101.mdx new file mode 100644 index 000000000..36bd4fc70 --- /dev/null +++ b/lerobot/common/robots/so101_follower/so101.mdx @@ -0,0 +1,381 @@ +# SO-101 + +In the steps below, we explain how to assemble our flagship robot, the SO-101. + +## Source the parts + +Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with a link to source the parts, as well as the instructions to 3D print the parts. +And advise if it's your first time printing or if you don't own a 3D printer. + +## Install LeRobot 🤗 + +To install LeRobot, follow our [Installation Guide](./installation) + +In addition to these instructions, you need to install the Feetech SDK: +```bash +pip install -e ".[feetech]" +``` + +## Step-by-Step Assembly Instructions + +The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however, uses three differently geared motors to make sure it can both sustain its own weight and it can be moved without requiring much force. Which motor is needed for which joint is shown in the table below. + +| Leader-Arm Axis | Motor | Gear Ratio | +|-----------------|:-------:|:----------:| +| Base / Shoulder Pan | 1 | 1 / 191 | +| Shoulder Lift | 2 | 1 / 345 | +| Elbow Flex | 3 | 1 / 191 | +| Wrist Flex | 4 | 1 / 147 | +| Wrist Roll | 5 | 1 / 147 | +| Gripper | 6 | 1 / 147 | + +### Clean Parts +Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material. + +### Joint 1 + +- Place the first motor into the base. +- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom. +- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side). +- Install both motor horns, securing the top horn with a M3x6mm screw. +- Attach the shoulder part. +- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom +- Add the shoulder motor holder. + +
+ +
+ +### Joint 2 + +- Slide the second motor in from the top. +- Fasten the second motor with 4 M2x6mm screws. +- Attach both motor horns to motor 2, again use the M3x6mm horn screw. +- Attach the upper arm with 4 M3x6mm screws on each side. + +
+ +
+ +### Joint 3 + +- Insert motor 3 and fasten using 4 M2x6mm screws +- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw. +- Connect the forearm to motor 3 using 4 M3x6mm screws on each side. + +
+ +
+ +### Joint 4 + +- Slide over motor holder 4. +- Slide in motor 4. +- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw. + +
+ +
+ +### Joint 5 + +- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws. +- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw. +- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides. + +
+ +
+ +### Gripper / Handle + + + + +- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws. +- Insert the gripper motor and secure it with 2 M2x6mm screws on each side. +- Attach the motor horns and again use a M3x6mm horn screw. +- Install the gripper claw and secure it with 4 M3x6mm screws on both sides. + +
+ +
+ +
+ + +- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws. +- Attach the handle to motor 5 using 1 M2x6mm screw. +- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw. +- Attach the follower trigger with 4 M3x6mm screws. + +
+ +
+ +
+
+ +## Configure the motors + +### 1. Find the USB ports associated with each arm + +To find the port for each bus servo adapter, run this script: +```bash +python lerobot/find_port.py +``` + + + + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] +Remove the USB cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/tty.usbmodem575E0032081 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/tty.usbmodem575E0032081` corresponding to your leader or follower arm. + + + + +On Linux, you might need to give access to the USB ports by running: +```bash +sudo chmod 666 /dev/ttyACM0 +sudo chmod 666 /dev/ttyACM1 +``` + +Example output: + +``` +Finding all available ports for the MotorBus. +['/dev/ttyACM0', '/dev/ttyACM1'] +Remove the usb cable from your MotorsBus and press Enter when done. + +[...Disconnect corresponding leader or follower arm and press Enter...] + +The port of this MotorsBus is /dev/ttyACM1 +Reconnect the USB cable. +``` + +Where the found port is: `/dev/ttyACM1` corresponding to your leader or follower arm. + + + + +### 2. Set the motors ids and baudrates + +Each motor is identified by a unique id on the bus. When brand new, motors usually come with a default id of `1`. For the communication to work properly between the motors and the controller, we first need to set a unique, different id to each motor. Additionally, the speed at which data is transmitted on the bus is determined by the baudrate. In order to talk to each other, the controller and all the motors need to be configured with the same baudrate. + +To that end, we first need to connect to each motor individually with the controller in order to set these. Since we will write these parameters in the non-volatile section of the motors' internal memory (EEPROM), we'll only need to do this once. + +If you are repurposing motors from another robot, you will probably also need to perform this step as the ids and baudrate likely won't match. + +The video below shows the sequence of steps for setting the motor ids. + +##### Setup motors video + +
+ +
+ +#### Follower + +Connect the usb cable from your computer and the power supply to the follower arm's controller board. Then, run the following command or run the API example with the port you got from the previous step. You'll also need to give your leader arm a name with the `id` parameter. + + + + +```bash +python -m lerobot.setup_motors \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step +``` + + + +```python +from lerobot.common.robots.so101_follower import SO101Follower, SO101FollowerConfig + +config = SO101FollowerConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_follower_arm", +) +follower = SO101Follower(config) +follower.setup_motors() +``` + + + +You should see the following instruction +```bash +Connect the controller board to the 'gripper' motor only and press enter. +``` + +As instructed, plug the gripper's motor. Make sure it's the only motor connected to the board, and that the motor itself is not yet daisy-chained to any other motor. As you press `[Enter]`, the script will automatically set the id and baudrate for that motor. + +
+Troubleshooting + + If you get an error at that point, check your cables and make sure they are plugged in properly: +
    +
  • Power supply
  • +
  • USB cable between your computer and the controller board
  • +
  • The 3-pin cable from the controller board to the motor
  • +
+ + If you are using a Waveshare controller board, make sure that the two jumpers are set on the `B` channel (USB). +
+ +You should then see the following message: +```bash +'gripper' motor id set to 6 +``` + +Followed by the next instruction: +```bash +Connect the controller board to the 'wrist_roll' motor only and press enter. +``` + +You can disconnect the 3-pin cable from the controller board, but you can leave it connected to the gripper motor on the other end, as it will already be in the right place. Now, plug in another 3-pin cable to the wrist roll motor and connect it to the controller board. As with the previous motor, make sure it is the only motor connected to the board and that the motor itself isn't connected to any other one. + +Repeat the operation for each motor as instructed. + +> [!TIP] +> Check your cabling at each step before pressing Enter. For instance, the power supply cable might disconnect as you manipulate the board. + +When you are done, the script will simply finish, at which point the motors are ready to be used. You can now plug the 3-pin cable from each motor to the next one, and the cable from the first motor (the 'shoulder pan' with id=1) to the controller board, which can now be attached to the base of the arm. + +#### Leader +Do the same steps for the leader arm. + + + + +```bash +python -m lerobot.setup_motors \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step +``` + + + +```python +from lerobot.common.teleoperators.so101_leader import SO101Leader, SO101LeaderConfig + +config = SO101LeaderConfig( + port="/dev/tty.usbmodem585A0076841", + id="my_awesome_leader_arm", +) +leader = SO101Leader(config) +leader.setup_motors() +``` + + + +## Calibrate + +Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. +The calibration process is very important because it allows a neural network trained on one robot to work on another. + +#### Follower + +Run the following command or API example to calibrate the follower arm: + + + + +```bash +python -m lerobot.calibrate \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --robot.id=my_awesome_follower_arm # <- Give the robot a unique name +``` + + + +```python +from lerobot.common.robots.so101_follower import SO101FollowerConfig, SO101Follower + +config = SO101FollowerConfig( + port="/dev/tty.usbmodem585A0076891", + id="my_awesome_follower_arm", +) + +follower = SO101Follower(config) +follower.connect(calibrate=False) +follower.calibrate() +follower.disconnect() +``` + + + +The video below shows how to perform the calibration. First you need to move the robot to the position where all joints are in the middle of their ranges. Then after pressing enter you have to move each joint through its full range of motion. + +##### Calibration video + +
+ +
+ +#### Leader + +Do the same steps to calibrate the leader arm, run the following command or API example: + + + + +```bash +python -m lerobot.calibrate \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot + --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name +``` + + + +```python +from lerobot.common.teleoperators.so101_leader import SO101LeaderConfig, SO101Leader + +config = SO101LeaderConfig( + port="/dev/tty.usbmodem58760431551", + id="my_awesome_leader_arm", +) + +leader = SO101Leader(config) +leader.connect(calibrate=False) +leader.calibrate() +leader.disconnect() +``` + + + +Congrats 🎉, your robot is all set to learn a task on its own. Start training it by following this tutorial: [Getting started with real-world robots](./getting_started_real_world_robot) + +> [!TIP] +> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb). diff --git a/lerobot/common/robots/so101_follower/so101_follower.py b/lerobot/common/robots/so101_follower/so101_follower.py new file mode 100644 index 000000000..a3c7aa0c2 --- /dev/null +++ b/lerobot/common/robots/so101_follower/so101_follower.py @@ -0,0 +1,210 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.feetech import ( + FeetechMotorsBus, + OperatingMode, +) + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_so101_follower import SO101FollowerConfig + +logger = logging.getLogger(__name__) + + +class SO101Follower(Robot): + """ + SO-101 Follower Arm designed by TheRobotStudio and Hugging Face. + """ + + config_class = SO101FollowerConfig + name = "so101_follower" + + def __init__(self, config: SO101FollowerConfig): + super().__init__(config) + self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "sts3215", norm_mode_body), + "shoulder_lift": Motor(2, "sts3215", norm_mode_body), + "elbow_flex": Motor(3, "sts3215", norm_mode_body), + "wrist_flex": Motor(4, "sts3215", norm_mode_body), + "wrist_roll": Motor(5, "sts3215", norm_mode_body), + "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + We assume that at connection time, arm is in a rest position, + and torque can be safely disabled to run calibration. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + print( + "Move all joints sequentially through their entire ranges " + "of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion() + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print("Calibration saved to", self.calibration_fpath) + + def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + # Set P_Coefficient to lower value to avoid shakiness (Default is 32) + self.bus.write("P_Coefficient", motor, 16) + # Set I_Coefficient and D_Coefficient to default value 0 and 32 + self.bus.write("I_Coefficient", motor, 0) + self.bus.write("D_Coefficient", motor, 32) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + # Read arm position + start = time.perf_counter() + obs_dict = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + """Command arm to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Raises: + RobotDeviceNotConnectedError: if robot is not connected. + + Returns: + the action sent to the motors, potentially clipped. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + # Send goal position to the arm + self.bus.sync_write("Goal_Position", goal_pos) + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/examples/8_use_stretch.md b/lerobot/common/robots/stretch3/README.md similarity index 96% rename from examples/8_use_stretch.md rename to lerobot/common/robots/stretch3/README.md index a7a7dde17..982e72571 100644 --- a/examples/8_use_stretch.md +++ b/lerobot/common/robots/stretch3/README.md @@ -99,7 +99,7 @@ This is equivalent to running `stretch_robot_home.py` > **Note:** If you run any of the LeRobot scripts below and Stretch is not properly homed, it will automatically home/calibrate first. **Teleoperate** -Before trying teleoperation, you need activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation). +Before trying teleoperation, you need to activate the gamepad controller by pressing the middle button. For more info, see Stretch's [doc](https://docs.hello-robot.com/0.3/getting_started/hello_robot/#gamepad-teleoperation). Now try out teleoperation (see above documentation to learn about the gamepad controls): diff --git a/lerobot/common/robots/stretch3/__init__.py b/lerobot/common/robots/stretch3/__init__.py new file mode 100644 index 000000000..e2a859cde --- /dev/null +++ b/lerobot/common/robots/stretch3/__init__.py @@ -0,0 +1,2 @@ +from .configuration_stretch3 import Stretch3RobotConfig +from .robot_stretch3 import Stretch3Robot diff --git a/lerobot/common/robots/stretch3/configuration_stretch3.py b/lerobot/common/robots/stretch3/configuration_stretch3.py new file mode 100644 index 000000000..e62e4fa01 --- /dev/null +++ b/lerobot/common/robots/stretch3/configuration_stretch3.py @@ -0,0 +1,58 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.cameras import CameraConfig +from lerobot.common.cameras.opencv import OpenCVCameraConfig +from lerobot.common.cameras.realsense import RealSenseCameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("stretch3") +@dataclass +class Stretch3RobotConfig(RobotConfig): + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + max_relative_target: int | None = None + + # cameras + cameras: dict[str, CameraConfig] = field( + default_factory=lambda: { + "navigation": OpenCVCameraConfig( + index_or_path="/dev/hello-nav-head-camera", + fps=10, + width=1280, + height=720, + rotation=-90, + ), + "head": RealSenseCameraConfig( + name="Intel RealSense D435I", + fps=30, + width=640, + height=480, + rotation=90, + ), + "wrist": RealSenseCameraConfig( + name="Intel RealSense D405", + fps=30, + width=640, + height=480, + ), + } + ) + + mock: bool = False diff --git a/lerobot/common/robots/stretch3/robot_stretch3.py b/lerobot/common/robots/stretch3/robot_stretch3.py new file mode 100644 index 000000000..048db381f --- /dev/null +++ b/lerobot/common/robots/stretch3/robot_stretch3.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +from stretch_body.gamepad_teleop import GamePadTeleop +from stretch_body.robot import Robot as StretchAPI +from stretch_body.robot_params import RobotParams + +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.constants import OBS_IMAGES, OBS_STATE +from lerobot.common.datasets.utils import get_nested_item + +from ..robot import Robot +from .configuration_stretch3 import Stretch3RobotConfig + +# {lerobot_keys: stretch.api.keys} +STRETCH_MOTORS = { + "head_pan.pos": "head.head_pan.pos", + "head_tilt.pos": "head.head_tilt.pos", + "lift.pos": "lift.pos", + "arm.pos": "arm.pos", + "wrist_pitch.pos": "end_of_arm.wrist_pitch.pos", + "wrist_roll.pos": "end_of_arm.wrist_roll.pos", + "wrist_yaw.pos": "end_of_arm.wrist_yaw.pos", + "gripper.pos": "end_of_arm.stretch_gripper.pos", + "base_x.vel": "base.x_vel", + "base_y.vel": "base.y_vel", + "base_theta.vel": "base.theta_vel", +} + + +class Stretch3Robot(Robot): + """[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot.""" + + config_class = Stretch3RobotConfig + name = "stretch3" + + def __init__(self, config: Stretch3RobotConfig): + raise NotImplementedError + super().__init__(config) + + self.config = config + self.robot_type = self.config.type + + self.api = StretchAPI() + self.cameras = make_cameras_from_configs(config.cameras) + + self.is_connected = False + self.logs = {} + + self.teleop = None # TODO remove + + # TODO(aliberts): test this + RobotParams.set_logging_level("WARNING") + RobotParams.set_logging_formatter("brief_console_formatter") + + self.state_keys = None + self.action_keys = None + + @property + def observation_features(self) -> dict: + return { + "dtype": "float32", + "shape": (len(STRETCH_MOTORS),), + "names": {"motors": list(STRETCH_MOTORS)}, + } + + @property + def action_features(self) -> dict: + return self.observation_features + + @property + def camera_features(self) -> dict[str, dict]: + cam_ft = {} + for cam_key, cam in self.cameras.items(): + cam_ft[cam_key] = { + "shape": (cam.height, cam.width, cam.channels), + "names": ["height", "width", "channels"], + "info": None, + } + return cam_ft + + def connect(self) -> None: + self.is_connected = self.api.startup() + if not self.is_connected: + print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") + raise ConnectionError() + + for cam in self.cameras.values(): + cam.connect() + self.is_connected = self.is_connected and cam.is_connected + + if not self.is_connected: + print("Could not connect to the cameras, check that all cameras are plugged-in.") + raise ConnectionError() + + self.calibrate() + + def calibrate(self) -> None: + if not self.api.is_homed(): + self.api.home() + + def _get_state(self) -> dict: + status = self.api.get_status() + return {k: get_nested_item(status, v, sep=".") for k, v in STRETCH_MOTORS.items()} + + def get_observation(self) -> dict[str, np.ndarray]: + obs_dict = {} + + # Read Stretch state + before_read_t = time.perf_counter() + state = self._get_state() + self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t + + if self.state_keys is None: + self.state_keys = list(state) + + state = np.asarray(list(state.values())) + obs_dict[OBS_STATE] = state + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + before_camread_t = time.perf_counter() + obs_dict[f"{OBS_IMAGES}.{cam_key}"] = cam.async_read() + self.logs[f"read_camera_{cam_key}_dt_s"] = cam.logs["delta_timestamp_s"] + self.logs[f"async_read_camera_{cam_key}_dt_s"] = time.perf_counter() - before_camread_t + + return obs_dict + + def send_action(self, action: np.ndarray) -> np.ndarray: + if not self.is_connected: + raise ConnectionError() + + if self.teleop is None: + self.teleop = GamePadTeleop(robot_instance=False) + self.teleop.startup(robot=self) + + if self.action_keys is None: + dummy_action = self.teleop.gamepad_controller.get_state() + self.action_keys = list(dummy_action.keys()) + + action_dict = dict(zip(self.action_keys, action.tolist(), strict=True)) + + before_write_t = time.perf_counter() + self.teleop.do_motion(state=action_dict, robot=self) + self.push_command() + self.logs["write_pos_dt_s"] = time.perf_counter() - before_write_t + + # TODO(aliberts): return action_sent when motion is limited + return action + + def print_logs(self) -> None: + pass + # TODO(aliberts): move robot-specific logs logic here + + def teleop_safety_stop(self) -> None: + if self.teleop is not None: + self.teleop._safety_stop(robot=self) + + def disconnect(self) -> None: + self.api.stop() + if self.teleop is not None: + self.teleop.gamepad_controller.stop() + self.teleop.stop() + + for cam in self.cameras.values(): + cam.disconnect() + + self.is_connected = False diff --git a/lerobot/common/robots/utils.py b/lerobot/common/robots/utils.py new file mode 100644 index 000000000..ccc1c58e8 --- /dev/null +++ b/lerobot/common/robots/utils.py @@ -0,0 +1,95 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pprint import pformat + +from lerobot.common.robots import RobotConfig + +from .robot import Robot + + +def make_robot_from_config(config: RobotConfig) -> Robot: + if config.type == "koch_follower": + from .koch_follower import KochFollower + + return KochFollower(config) + elif config.type == "so100_follower": + from .so100_follower import SO100Follower + + return SO100Follower(config) + elif config.type == "so100_follower_end_effector": + from .so100_follower import SO100FollowerEndEffector + + return SO100FollowerEndEffector(config) + elif config.type == "so101_follower": + from .so101_follower import SO101Follower + + return SO101Follower(config) + elif config.type == "lekiwi": + from .lekiwi import LeKiwi + + return LeKiwi(config) + elif config.type == "stretch3": + from .stretch3 import Stretch3Robot + + return Stretch3Robot(config) + elif config.type == "viperx": + from .viperx import ViperX + + return ViperX(config) + elif config.type == "mock_robot": + from tests.mocks.mock_robot import MockRobot + + return MockRobot(config) + else: + raise ValueError(config.type) + + +def ensure_safe_goal_position( + goal_present_pos: dict[str, tuple[float, float]], max_relative_target: float | dict[float] +) -> dict[str, float]: + """Caps relative action target magnitude for safety.""" + + if isinstance(max_relative_target, float): + diff_cap = dict.fromkeys(goal_present_pos, max_relative_target) + elif isinstance(max_relative_target, dict): + if not set(goal_present_pos) == set(max_relative_target): + raise ValueError("max_relative_target keys must match those of goal_present_pos.") + diff_cap = max_relative_target + else: + raise TypeError(max_relative_target) + + warnings_dict = {} + safe_goal_positions = {} + for key, (goal_pos, present_pos) in goal_present_pos.items(): + diff = goal_pos - present_pos + max_diff = diff_cap[key] + safe_diff = min(diff, max_diff) + safe_diff = max(safe_diff, -max_diff) + safe_goal_pos = present_pos + safe_diff + safe_goal_positions[key] = safe_goal_pos + if abs(safe_goal_pos - goal_pos) > 1e-4: + warnings_dict[key] = { + "original goal_pos": goal_pos, + "safe goal_pos": safe_goal_pos, + } + + if warnings_dict: + logging.warning( + "Relative goal position magnitude had to be clamped to be safe.\n" + f"{pformat(warnings_dict, indent=4)}" + ) + + return safe_goal_positions diff --git a/lerobot/common/robots/viperx/README.md b/lerobot/common/robots/viperx/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/lerobot/common/robots/viperx/__init__.py b/lerobot/common/robots/viperx/__init__.py new file mode 100644 index 000000000..522d02f1c --- /dev/null +++ b/lerobot/common/robots/viperx/__init__.py @@ -0,0 +1,2 @@ +from .config_viperx import ViperXConfig +from .viperx import ViperX diff --git a/lerobot/common/robots/viperx/config_viperx.py b/lerobot/common/robots/viperx/config_viperx.py new file mode 100644 index 000000000..6c7e2cc84 --- /dev/null +++ b/lerobot/common/robots/viperx/config_viperx.py @@ -0,0 +1,45 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.common.cameras import CameraConfig + +from ..config import RobotConfig + + +@RobotConfig.register_subclass("viperx") +@dataclass +class ViperXConfig(RobotConfig): + port: str # Port to connect to the arm + + disable_torque_on_disconnect: bool = True + + # /!\ FOR SAFETY, READ THIS /!\ + # `max_relative_target` limits the magnitude of the relative positional target vector for safety purposes. + # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as + # the number of motors in your follower arms. + # For Aloha, for every goal position request, motor rotations are capped at 5 degrees by default. + # When you feel more confident with teleoperation or running the policy, you can extend + # this safety limit and even removing it by setting it to `null`. + # Also, everything is expected to work safely out-of-the-box, but we highly advise to + # first try to teleoperate the grippers only (by commenting out the rest of the motors in this yaml), + # then to gradually add more motors (by uncommenting), until you can teleoperate both arms fully + max_relative_target: int | None = 5 + + # cameras + cameras: dict[str, CameraConfig] = field(default_factory=dict) + # Troubleshooting: If one of your IntelRealSense cameras freeze during + # data recording due to bandwidth limit, you might need to plug the camera + # on another USB hub or PCIe card. diff --git a/lerobot/common/robots/viperx/viperx.py b/lerobot/common/robots/viperx/viperx.py new file mode 100644 index 000000000..8ed8ef74c --- /dev/null +++ b/lerobot/common/robots/viperx/viperx.py @@ -0,0 +1,233 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from functools import cached_property +from typing import Any + +from lerobot.common.cameras.utils import make_cameras_from_configs +from lerobot.common.constants import OBS_STATE +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.dynamixel import ( + DynamixelMotorsBus, + OperatingMode, +) + +from ..robot import Robot +from ..utils import ensure_safe_goal_position +from .config_viperx import ViperXConfig + +logger = logging.getLogger(__name__) + + +class ViperX(Robot): + """ + [ViperX](https://www.trossenrobotics.com/viperx-300) developed by Trossen Robotics + """ + + config_class = ViperXConfig + name = "viperx" + + def __init__( + self, + config: ViperXConfig, + ): + raise NotImplementedError + super().__init__(config) + self.config = config + self.bus = DynamixelMotorsBus( + port=self.config.port, + motors={ + "waist": Motor(1, "xm540-w270", MotorNormMode.RANGE_M100_100), + "shoulder": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100), + "shoulder_shadow": Motor(3, "xm540-w270", MotorNormMode.RANGE_M100_100), + "elbow": Motor(4, "xm540-w270", MotorNormMode.RANGE_M100_100), + "elbow_shadow": Motor(5, "xm540-w270", MotorNormMode.RANGE_M100_100), + "forearm_roll": Motor(6, "xm540-w270", MotorNormMode.RANGE_M100_100), + "wrist_angle": Motor(7, "xm540-w270", MotorNormMode.RANGE_M100_100), + "wrist_rotate": Motor(8, "xm430-w350", MotorNormMode.RANGE_M100_100), + "gripper": Motor(9, "xm430-w350", MotorNormMode.RANGE_0_100), + }, + ) + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values()) + + def connect(self, calibrate: bool = True) -> None: + """ + We assume that at connection time, arm is in a rest position, + and torque can be safely disabled to run calibration. + """ + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + for cam in self.cameras.values(): + cam.connect() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + raise NotImplementedError # TODO(aliberts): adapt code below (copied from koch + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + input("Move robot to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + full_turn_motors = ["shoulder_pan", "wrist_roll"] + unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors] + print( + f"Move all joints except {full_turn_motors} sequentially through their entire " + "ranges of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + for motor in full_turn_motors: + range_mins[motor] = 0 + range_maxes[motor] = 4095 + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + logger.info(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + with self.bus.torque_disabled(): + self.bus.configure_motors() + + # Set secondary/shadow ID for shoulder and elbow. These joints have two motors. + # As a result, if only one of them is required to move to a certain position, + # the other will follow. This is to avoid breaking the motors. + self.bus.write("Secondary_ID", "shoulder_shadow", 2) + self.bus.write("Secondary_ID", "elbow_shadow", 4) + + # Set a velocity limit of 131 as advised by Trossen Robotics + # TODO(aliberts): remove as it's actually useless in position control + self.bus.write("Velocity_Limit", 131) + + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos + # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling + # the arm, you could end up with a servo with a position 0 or 4095 at a crucial point. + # See: https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11 + for motor in self.bus.motors: + if motor != "gripper": + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + # Use 'position control current based' for follower gripper to be limited by the limit of the + # current. It can grasp an object without forcing too much even tho, it's goal position is a + # complete grasp (both gripper fingers are ordered to join and reach a touch). + self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + + def get_observation(self) -> dict[str, Any]: + """The returned observations do not have a batch dimension.""" + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + obs_dict = {} + + # Read arm position + start = time.perf_counter() + obs_dict[OBS_STATE] = self.bus.sync_read("Present_Position") + obs_dict = {f"{motor}.pos": val for motor, val in obs_dict.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read state: {dt_ms:.1f}ms") + + # Capture images from cameras + for cam_key, cam in self.cameras.items(): + start = time.perf_counter() + obs_dict[cam_key] = cam.async_read() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + + return obs_dict + + def send_action(self, action: dict[str, float]) -> dict[str, float]: + """Command arm to move to a target joint configuration. + + The relative action magnitude may be clipped depending on the configuration parameter + `max_relative_target`. In this case, the action sent differs from original action. + Thus, this function always returns the action actually sent. + + Args: + action (dict[str, float]): The goal positions for the motors. + + Returns: + dict[str, float]: The action sent to the motors, potentially clipped. + """ + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")} + + # Cap goal position when too far away from present position. + # /!\ Slower fps expected due to reading from the follower. + if self.config.max_relative_target is not None: + present_pos = self.bus.sync_read("Present_Position") + goal_present_pos = {key: (g_pos, present_pos[key]) for key, g_pos in goal_pos.items()} + goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target) + + # Send goal position to the arm + self.bus.sync_write("Goal_Position", goal_pos) + return {f"{motor}.pos": val for motor, val in goal_pos.items()} + + def disconnect(self): + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect(self.config.disable_torque_on_disconnect) + for cam in self.cameras.values(): + cam.disconnect() + + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/teleoperators/__init__.py b/lerobot/common/teleoperators/__init__.py new file mode 100644 index 000000000..ec93547f7 --- /dev/null +++ b/lerobot/common/teleoperators/__init__.py @@ -0,0 +1,3 @@ +from .config import TeleoperatorConfig +from .teleoperator import Teleoperator +from .utils import make_teleoperator_from_config diff --git a/lerobot/common/teleoperators/config.py b/lerobot/common/teleoperators/config.py new file mode 100644 index 000000000..1b42b4edb --- /dev/null +++ b/lerobot/common/teleoperators/config.py @@ -0,0 +1,31 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from dataclasses import dataclass +from pathlib import Path + +import draccus + + +@dataclass(kw_only=True) +class TeleoperatorConfig(draccus.ChoiceRegistry, abc.ABC): + # Allows to distinguish between different teleoperators of the same type + id: str | None = None + # Directory to store calibration file + calibration_dir: Path | None = None + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) diff --git a/lerobot/common/teleoperators/gamepad/__init__.py b/lerobot/common/teleoperators/gamepad/__init__.py new file mode 100644 index 000000000..6f9f7fbd9 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/__init__.py @@ -0,0 +1,18 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .configuration_gamepad import GamepadTeleopConfig +from .teleop_gamepad import GamepadTeleop diff --git a/lerobot/common/teleoperators/gamepad/configuration_gamepad.py b/lerobot/common/teleoperators/gamepad/configuration_gamepad.py new file mode 100644 index 000000000..b3a565c07 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/configuration_gamepad.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("gamepad") +@dataclass +class GamepadTeleopConfig(TeleoperatorConfig): + use_gripper: bool = True diff --git a/lerobot/common/teleoperators/gamepad/gamepad_utils.py b/lerobot/common/teleoperators/gamepad/gamepad_utils.py new file mode 100644 index 000000000..21a293c77 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/gamepad_utils.py @@ -0,0 +1,480 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + + +class InputController: + """Base class for input controllers that generate motion deltas.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0): + """ + Initialize the controller. + + Args: + x_step_size: Base movement step size in meters + y_step_size: Base movement step size in meters + z_step_size: Base movement step size in meters + """ + self.x_step_size = x_step_size + self.y_step_size = y_step_size + self.z_step_size = z_step_size + self.running = True + self.episode_end_status = None # None, "success", or "failure" + self.intervention_flag = False + self.open_gripper_command = False + self.close_gripper_command = False + + def start(self): + """Start the controller and initialize resources.""" + pass + + def stop(self): + """Stop the controller and release resources.""" + pass + + def get_deltas(self): + """Get the current movement deltas (dx, dy, dz) in meters.""" + return 0.0, 0.0, 0.0 + + def should_quit(self): + """Return True if the user has requested to quit.""" + return not self.running + + def update(self): + """Update controller state - call this once per frame.""" + pass + + def __enter__(self): + """Support for use in 'with' statements.""" + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Ensure resources are released when exiting 'with' block.""" + self.stop() + + def get_episode_end_status(self): + """ + Get the current episode end status. + + Returns: + None if episode should continue, "success" or "failure" otherwise + """ + status = self.episode_end_status + self.episode_end_status = None # Reset after reading + return status + + def should_intervene(self): + """Return True if intervention flag was set.""" + return self.intervention_flag + + def gripper_command(self): + """Return the current gripper command.""" + if self.open_gripper_command == self.close_gripper_command: + return "stay" + elif self.open_gripper_command: + return "open" + elif self.close_gripper_command: + return "close" + + +class KeyboardController(InputController): + """Generate motion deltas from keyboard input.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0): + super().__init__(x_step_size, y_step_size, z_step_size) + self.key_states = { + "forward_x": False, + "backward_x": False, + "forward_y": False, + "backward_y": False, + "forward_z": False, + "backward_z": False, + "quit": False, + "success": False, + "failure": False, + } + self.listener = None + + def start(self): + """Start the keyboard listener.""" + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = True + elif key == keyboard.Key.down: + self.key_states["backward_x"] = True + elif key == keyboard.Key.left: + self.key_states["forward_y"] = True + elif key == keyboard.Key.right: + self.key_states["backward_y"] = True + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = True + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = True + elif key == keyboard.Key.esc: + self.key_states["quit"] = True + self.running = False + return False + elif key == keyboard.Key.enter: + self.key_states["success"] = True + self.episode_end_status = "success" + elif key == keyboard.Key.backspace: + self.key_states["failure"] = True + self.episode_end_status = "failure" + except AttributeError: + pass + + def on_release(key): + try: + if key == keyboard.Key.up: + self.key_states["forward_x"] = False + elif key == keyboard.Key.down: + self.key_states["backward_x"] = False + elif key == keyboard.Key.left: + self.key_states["forward_y"] = False + elif key == keyboard.Key.right: + self.key_states["backward_y"] = False + elif key == keyboard.Key.shift: + self.key_states["backward_z"] = False + elif key == keyboard.Key.shift_r: + self.key_states["forward_z"] = False + elif key == keyboard.Key.enter: + self.key_states["success"] = False + elif key == keyboard.Key.backspace: + self.key_states["failure"] = False + except AttributeError: + pass + + self.listener = keyboard.Listener(on_press=on_press, on_release=on_release) + self.listener.start() + + print("Keyboard controls:") + print(" Arrow keys: Move in X-Y plane") + print(" Shift and Shift_R: Move in Z axis") + print(" Enter: End episode with SUCCESS") + print(" Backspace: End episode with FAILURE") + print(" ESC: Exit") + + def stop(self): + """Stop the keyboard listener.""" + if self.listener and self.listener.is_alive(): + self.listener.stop() + + def get_deltas(self): + """Get the current movement deltas from keyboard state.""" + delta_x = delta_y = delta_z = 0.0 + + if self.key_states["forward_x"]: + delta_x += self.x_step_size + if self.key_states["backward_x"]: + delta_x -= self.x_step_size + if self.key_states["forward_y"]: + delta_y += self.y_step_size + if self.key_states["backward_y"]: + delta_y -= self.y_step_size + if self.key_states["forward_z"]: + delta_z += self.z_step_size + if self.key_states["backward_z"]: + delta_z -= self.z_step_size + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if ESC was pressed.""" + return self.key_states["quit"] + + def should_save(self): + """Return True if Enter was pressed (save episode).""" + return self.key_states["success"] or self.key_states["failure"] + + +class GamepadController(InputController): + """Generate motion deltas from gamepad input.""" + + def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1): + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.joystick = None + self.intervention_flag = False + + def start(self): + """Initialize pygame and the gamepad.""" + import pygame + + pygame.init() + pygame.joystick.init() + + if pygame.joystick.get_count() == 0: + logging.error("No gamepad detected. Please connect a gamepad and try again.") + self.running = False + return + + self.joystick = pygame.joystick.Joystick(0) + self.joystick.init() + logging.info(f"Initialized gamepad: {self.joystick.get_name()}") + + print("Gamepad controls:") + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick (vertical): Move in Z axis") + print(" B/Circle button: Exit") + print(" Y/Triangle button: End episode with SUCCESS") + print(" A/Cross button: End episode with FAILURE") + print(" X/Square button: Rerecord episode") + + def stop(self): + """Clean up pygame resources.""" + import pygame + + if pygame.joystick.get_init(): + if self.joystick: + self.joystick.quit() + pygame.joystick.quit() + pygame.quit() + + def update(self): + """Process pygame events to get fresh gamepad readings.""" + import pygame + + for event in pygame.event.get(): + if event.type == pygame.JOYBUTTONDOWN: + if event.button == 3: + self.episode_end_status = "success" + # A button (1) for failure + elif event.button == 1: + self.episode_end_status = "failure" + # X button (0) for rerecord + elif event.button == 0: + self.episode_end_status = "rerecord_episode" + + # RB button (6) for closing gripper + elif event.button == 6: + self.close_gripper_command = True + + # LT button (7) for opening gripper + elif event.button == 7: + self.open_gripper_command = True + + # Reset episode status on button release + elif event.type == pygame.JOYBUTTONUP: + if event.button in [0, 2, 3]: + self.episode_end_status = None + + elif event.button == 6: + self.close_gripper_command = False + + elif event.button == 7: + self.open_gripper_command = False + + # Check for RB button (typically button 5) for intervention flag + if self.joystick.get_button(5): + self.intervention_flag = True + else: + self.intervention_flag = False + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + import pygame + + try: + # Read joystick axes + # Left stick X and Y (typically axes 0 and 1) + x_input = self.joystick.get_axis(0) # Left/Right + y_input = self.joystick.get_axis(1) # Up/Down (often inverted) + + # Right stick Y (typically axis 3 or 4) + z_input = self.joystick.get_axis(3) # Up/Down for Z + + # Apply deadzone to avoid drift + x_input = 0 if abs(x_input) < self.deadzone else x_input + y_input = 0 if abs(y_input) < self.deadzone else y_input + z_input = 0 if abs(z_input) < self.deadzone else z_input + + # Calculate deltas (note: may need to invert axes depending on controller) + delta_x = -y_input * self.y_step_size # Forward/backward + delta_y = -x_input * self.x_step_size # Left/right + delta_z = -z_input * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + except pygame.error: + logging.error("Error reading gamepad. Is it still connected?") + return 0.0, 0.0, 0.0 + + +class GamepadControllerHID(InputController): + """Generate motion deltas from gamepad input using HIDAPI.""" + + def __init__( + self, + x_step_size=1.0, + y_step_size=1.0, + z_step_size=1.0, + deadzone=0.1, + ): + """ + Initialize the HID gamepad controller. + + Args: + step_size: Base movement step size in meters + z_scale: Scaling factor for Z-axis movement + deadzone: Joystick deadzone to prevent drift + """ + super().__init__(x_step_size, y_step_size, z_step_size) + self.deadzone = deadzone + self.device = None + self.device_info = None + + # Movement values (normalized from -1.0 to 1.0) + self.left_x = 0.0 + self.left_y = 0.0 + self.right_x = 0.0 + self.right_y = 0.0 + + # Button states + self.buttons = {} + self.quit_requested = False + self.save_requested = False + + def find_device(self): + """Look for the gamepad device by vendor and product ID.""" + import hid + + devices = hid.enumerate() + for device in devices: + device_name = device["product_string"] + if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]): + return device + + logging.error( + "No gamepad found, check the connection and the product string in HID to add your gamepad" + ) + return None + + def start(self): + """Connect to the gamepad using HIDAPI.""" + import hid + + self.device_info = self.find_device() + if not self.device_info: + self.running = False + return + + try: + logging.info(f"Connecting to gamepad at path: {self.device_info['path']}") + self.device = hid.device() + self.device.open_path(self.device_info["path"]) + self.device.set_nonblocking(1) + + manufacturer = self.device.get_manufacturer_string() + product = self.device.get_product_string() + logging.info(f"Connected to {manufacturer} {product}") + + logging.info("Gamepad controls (HID mode):") + logging.info(" Left analog stick: Move in X-Y plane") + logging.info(" Right analog stick: Move in Z axis (vertical)") + logging.info(" Button 1/B/Circle: Exit") + logging.info(" Button 2/A/Cross: End episode with SUCCESS") + logging.info(" Button 3/X/Square: End episode with FAILURE") + + except OSError as e: + logging.error(f"Error opening gamepad: {e}") + logging.error("You might need to run this with sudo/admin privileges on some systems") + self.running = False + + def stop(self): + """Close the HID device connection.""" + if self.device: + self.device.close() + self.device = None + + def update(self): + """ + Read and process the latest gamepad data. + Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading + """ + for _ in range(10): + self._update() + + def _update(self): + """Read and process the latest gamepad data.""" + if not self.device or not self.running: + return + + try: + # Read data from the gamepad + data = self.device.read(64) + # Interpret gamepad data - this will vary by controller model + # These offsets are for the Logitech RumblePad 2 + if data and len(data) >= 8: + # Normalize joystick values from 0-255 to -1.0-1.0 + self.left_x = (data[1] - 128) / 128.0 + self.left_y = (data[2] - 128) / 128.0 + self.right_x = (data[3] - 128) / 128.0 + self.right_y = (data[4] - 128) / 128.0 + + # Apply deadzone + self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x + self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y + self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x + self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y + + # Parse button states (byte 5 in the Logitech RumblePad 2) + buttons = data[5] + + # Check if RB is pressed then the intervention flag should be set + self.intervention_flag = data[6] in [2, 6, 10, 14] + + # Check if RT is pressed + self.open_gripper_command = data[6] in [8, 10, 12] + + # Check if LT is pressed + self.close_gripper_command = data[6] in [4, 6, 12] + + # Check if Y/Triangle button (bit 7) is pressed for saving + # Check if X/Square button (bit 5) is pressed for failure + # Check if A/Cross button (bit 4) is pressed for rerecording + if buttons & 1 << 7: + self.episode_end_status = "success" + elif buttons & 1 << 5: + self.episode_end_status = "failure" + elif buttons & 1 << 4: + self.episode_end_status = "rerecord_episode" + else: + self.episode_end_status = None + + except OSError as e: + logging.error(f"Error reading from gamepad: {e}") + + def get_deltas(self): + """Get the current movement deltas from gamepad state.""" + # Calculate deltas - invert as needed based on controller orientation + delta_x = -self.left_y * self.x_step_size # Forward/backward + delta_y = -self.left_x * self.y_step_size # Left/right + delta_z = -self.right_y * self.z_step_size # Up/down + + return delta_x, delta_y, delta_z + + def should_quit(self): + """Return True if quit button was pressed.""" + return self.quit_requested + + def should_save(self): + """Return True if save button was pressed.""" + return self.save_requested diff --git a/lerobot/common/teleoperators/gamepad/teleop_gamepad.py b/lerobot/common/teleoperators/gamepad/teleop_gamepad.py new file mode 100644 index 000000000..98a0647e2 --- /dev/null +++ b/lerobot/common/teleoperators/gamepad/teleop_gamepad.py @@ -0,0 +1,138 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from enum import IntEnum +from typing import Any + +import numpy as np + +from ..teleoperator import Teleoperator +from .configuration_gamepad import GamepadTeleopConfig + + +class GripperAction(IntEnum): + CLOSE = 0 + STAY = 1 + OPEN = 2 + + +gripper_action_map = { + "close": GripperAction.CLOSE.value, + "open": GripperAction.OPEN.value, + "stay": GripperAction.STAY.value, +} + + +class GamepadTeleop(Teleoperator): + """ + Teleop class to use gamepad inputs for control. + """ + + config_class = GamepadTeleopConfig + name = "gamepad" + + def __init__(self, config: GamepadTeleopConfig): + super().__init__(config) + self.config = config + self.robot_type = config.type + + self.gamepad = None + + @property + def action_features(self) -> dict: + if self.config.use_gripper: + return { + "dtype": "float32", + "shape": (4,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, + } + else: + return { + "dtype": "float32", + "shape": (3,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, + } + + @property + def feedback_features(self) -> dict: + return {} + + def connect(self) -> None: + # use HidApi for macos + if sys.platform == "darwin": + # NOTE: On macOS, pygame doesn’t reliably detect input from some controllers so we fall back to hidapi + from .gamepad_utils import GamepadControllerHID as Gamepad + else: + from .gamepad_utils import GamepadController as Gamepad + + self.gamepad = Gamepad() + self.gamepad.start() + + def get_action(self) -> dict[str, Any]: + # Update the controller to get fresh inputs + self.gamepad.update() + + # Get movement deltas from the controller + delta_x, delta_y, delta_z = self.gamepad.get_deltas() + + # Create action from gamepad input + gamepad_action = np.array([delta_x, delta_y, delta_z], dtype=np.float32) + + action_dict = { + "delta_x": gamepad_action[0], + "delta_y": gamepad_action[1], + "delta_z": gamepad_action[2], + } + + # Default gripper action is to stay + gripper_action = GripperAction.STAY.value + if self.config.use_gripper: + gripper_command = self.gamepad.gripper_command() + gripper_action = gripper_action_map[gripper_command] + action_dict["gripper"] = gripper_action + + return action_dict + + def disconnect(self) -> None: + """Disconnect from the gamepad.""" + if self.gamepad is not None: + self.gamepad.stop() + self.gamepad = None + + def is_connected(self) -> bool: + """Check if gamepad is connected.""" + return self.gamepad is not None + + def calibrate(self) -> None: + """Calibrate the gamepad.""" + # No calibration needed for gamepad + pass + + def is_calibrated(self) -> bool: + """Check if gamepad is calibrated.""" + # Gamepad doesn't require calibration + return True + + def configure(self) -> None: + """Configure the gamepad.""" + # No additional configuration needed + pass + + def send_feedback(self, feedback: dict) -> None: + """Send feedback to the gamepad.""" + # Gamepad doesn't support feedback + pass diff --git a/lerobot/common/teleoperators/keyboard/__init__.py b/lerobot/common/teleoperators/keyboard/__init__.py new file mode 100644 index 000000000..5761bf788 --- /dev/null +++ b/lerobot/common/teleoperators/keyboard/__init__.py @@ -0,0 +1,9 @@ +from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig +from .teleop_keyboard import KeyboardEndEffectorTeleop, KeyboardTeleop + +__all__ = [ + "KeyboardTeleopConfig", + "KeyboardTeleop", + "KeyboardEndEffectorTeleopConfig", + "KeyboardEndEffectorTeleop", +] diff --git a/lerobot/common/teleoperators/keyboard/configuration_keyboard.py b/lerobot/common/teleoperators/keyboard/configuration_keyboard.py new file mode 100644 index 000000000..5d5ef364f --- /dev/null +++ b/lerobot/common/teleoperators/keyboard/configuration_keyboard.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("keyboard") +@dataclass +class KeyboardTeleopConfig(TeleoperatorConfig): + # TODO(Steven): Consider setting in here the keys that we want to capture/listen + mock: bool = False + + +@TeleoperatorConfig.register_subclass("keyboard_ee") +@dataclass +class KeyboardEndEffectorTeleopConfig(KeyboardTeleopConfig): + use_gripper: bool = True diff --git a/lerobot/common/teleoperators/keyboard/teleop_keyboard.py b/lerobot/common/teleoperators/keyboard/teleop_keyboard.py new file mode 100644 index 000000000..bd3ab903e --- /dev/null +++ b/lerobot/common/teleoperators/keyboard/teleop_keyboard.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +import time +from queue import Queue +from typing import Any + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +from ..teleoperator import Teleoperator +from .configuration_keyboard import KeyboardEndEffectorTeleopConfig, KeyboardTeleopConfig + +PYNPUT_AVAILABLE = True +try: + if ("DISPLAY" not in os.environ) and ("linux" in sys.platform): + logging.info("No DISPLAY set. Skipping pynput import.") + raise ImportError("pynput blocked intentionally due to no display.") + + from pynput import keyboard +except ImportError: + keyboard = None + PYNPUT_AVAILABLE = False +except Exception as e: + keyboard = None + PYNPUT_AVAILABLE = False + logging.info(f"Could not import pynput: {e}") + + +class KeyboardTeleop(Teleoperator): + """ + Teleop class to use keyboard inputs for control. + """ + + config_class = KeyboardTeleopConfig + name = "keyboard" + + def __init__(self, config: KeyboardTeleopConfig): + super().__init__(config) + self.config = config + self.robot_type = config.type + + self.event_queue = Queue() + self.current_pressed = {} + self.listener = None + self.logs = {} + + @property + def action_features(self) -> dict: + return { + "dtype": "float32", + "shape": (len(self.arm),), + "names": {"motors": list(self.arm.motors)}, + } + + @property + def feedback_features(self) -> dict: + return {} + + @property + def is_connected(self) -> bool: + return PYNPUT_AVAILABLE and isinstance(self.listener, keyboard.Listener) and self.listener.is_alive() + + @property + def is_calibrated(self) -> bool: + pass + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError( + "Keyboard is already connected. Do not run `robot.connect()` twice." + ) + + if PYNPUT_AVAILABLE: + logging.info("pynput is available - enabling local keyboard listener.") + self.listener = keyboard.Listener( + on_press=self._on_press, + on_release=self._on_release, + ) + self.listener.start() + else: + logging.info("pynput not available - skipping local keyboard listener.") + self.listener = None + + def calibrate(self) -> None: + pass + + def _on_press(self, key): + if hasattr(key, "char"): + self.event_queue.put((key.char, True)) + + def _on_release(self, key): + if hasattr(key, "char"): + self.event_queue.put((key.char, False)) + if key == keyboard.Key.esc: + logging.info("ESC pressed, disconnecting.") + self.disconnect() + + def _drain_pressed_keys(self): + while not self.event_queue.empty(): + key_char, is_pressed = self.event_queue.get_nowait() + self.current_pressed[key_char] = is_pressed + + def configure(self): + pass + + def get_action(self) -> dict[str, Any]: + before_read_t = time.perf_counter() + + if not self.is_connected: + raise DeviceNotConnectedError( + "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." + ) + + self._drain_pressed_keys() + + # Generate action based on current key states + action = {key for key, val in self.current_pressed.items() if val} + self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t + + return dict.fromkeys(action, None) + + def send_feedback(self, feedback: dict[str, Any]) -> None: + pass + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError( + "KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`." + ) + if self.listener is not None: + self.listener.stop() + + +class KeyboardEndEffectorTeleop(KeyboardTeleop): + """ + Teleop class to use keyboard inputs for end effector control. + Designed to be used with the `So100FollowerEndEffector` robot. + """ + + config_class = KeyboardEndEffectorTeleopConfig + name = "keyboard_ee" + + def __init__(self, config: KeyboardEndEffectorTeleopConfig): + super().__init__(config) + self.config = config + self.misc_keys_queue = Queue() + + @property + def action_features(self) -> dict: + if self.config.use_gripper: + return { + "dtype": "float32", + "shape": (4,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2, "gripper": 3}, + } + else: + return { + "dtype": "float32", + "shape": (3,), + "names": {"delta_x": 0, "delta_y": 1, "delta_z": 2}, + } + + def _on_press(self, key): + if hasattr(key, "char"): + key = key.char + self.event_queue.put((key, True)) + + def _on_release(self, key): + if hasattr(key, "char"): + key = key.char + self.event_queue.put((key, False)) + + def get_action(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError( + "KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`." + ) + + self._drain_pressed_keys() + delta_x = 0.0 + delta_y = 0.0 + delta_z = 0.0 + + # Generate action based on current key states + for key, val in self.current_pressed.items(): + if key == keyboard.Key.up: + delta_x = int(val) + elif key == keyboard.Key.down: + delta_x = -int(val) + elif key == keyboard.Key.left: + delta_y = int(val) + elif key == keyboard.Key.right: + delta_y = -int(val) + elif key == keyboard.Key.shift: + delta_z = -int(val) + elif key == keyboard.Key.shift_r: + delta_z = int(val) + elif key == keyboard.Key.ctrl_r: + # Gripper actions are expected to be between 0 (close), 1 (stay), 2 (open) + gripper_action = int(val) + 1 + elif key == keyboard.Key.ctrl_l: + gripper_action = int(val) - 1 + elif val: + # If the key is pressed, add it to the misc_keys_queue + # this will record key presses that are not part of the delta_x, delta_y, delta_z + # this is useful for retrieving other events like interventions for RL, episode success, etc. + self.misc_keys_queue.put(key) + + self.current_pressed.clear() + + action_dict = { + "delta_x": delta_x, + "delta_y": delta_y, + "delta_z": delta_z, + } + + gripper_action = 1 # default gripper action is to stay + if self.config.use_gripper: + action_dict["gripper"] = gripper_action + + return action_dict diff --git a/lerobot/common/teleoperators/koch_leader/__init__.py b/lerobot/common/teleoperators/koch_leader/__init__.py new file mode 100644 index 000000000..ad2d6a0e4 --- /dev/null +++ b/lerobot/common/teleoperators/koch_leader/__init__.py @@ -0,0 +1,2 @@ +from .config_koch_leader import KochLeaderConfig +from .koch_leader import KochLeader diff --git a/lerobot/common/teleoperators/koch_leader/config_koch_leader.py b/lerobot/common/teleoperators/koch_leader/config_koch_leader.py new file mode 100644 index 000000000..64aaae123 --- /dev/null +++ b/lerobot/common/teleoperators/koch_leader/config_koch_leader.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("koch_leader") +@dataclass +class KochLeaderConfig(TeleoperatorConfig): + # Port to connect to the arm + port: str + + # Sets the arm in torque mode with the gripper motor set to this value. This makes it possible to squeeze + # the gripper and have it spring back to an open position on its own. + gripper_open_pos: float = 50.0 diff --git a/lerobot/common/teleoperators/koch_leader/koch_leader.py b/lerobot/common/teleoperators/koch_leader/koch_leader.py new file mode 100644 index 000000000..820acc87c --- /dev/null +++ b/lerobot/common/teleoperators/koch_leader/koch_leader.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.dynamixel import ( + DriveMode, + DynamixelMotorsBus, + OperatingMode, +) + +from ..teleoperator import Teleoperator +from .config_koch_leader import KochLeaderConfig + +logger = logging.getLogger(__name__) + + +class KochLeader(Teleoperator): + """ + - [Koch v1.0](https://github.com/AlexanderKoch-Koch/low_cost_robot), with and without the wrist-to-elbow + expansion, developed by Alexander Koch from [Tau Robotics](https://tau-robotics.com) + - [Koch v1.1](https://github.com/jess-moss/koch-v1-1) developed by Jess Moss + """ + + config_class = KochLeaderConfig + name = "koch_leader" + + def __init__(self, config: KochLeaderConfig): + super().__init__(config) + self.config = config + self.bus = DynamixelMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "xl330-m077", MotorNormMode.RANGE_M100_100), + "shoulder_lift": Motor(2, "xl330-m077", MotorNormMode.RANGE_M100_100), + "elbow_flex": Motor(3, "xl330-m077", MotorNormMode.RANGE_M100_100), + "wrist_flex": Motor(4, "xl330-m077", MotorNormMode.RANGE_M100_100), + "wrist_roll": Motor(5, "xl330-m077", MotorNormMode.RANGE_M100_100), + "gripper": Motor(6, "xl330-m077", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + + @property + def action_features(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.bus.is_connected + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + self.bus.write("Drive_Mode", "elbow_flex", DriveMode.INVERTED.value) + drive_modes = {motor: 1 if motor == "elbow_flex" else 0 for motor in self.bus.motors} + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + full_turn_motors = ["shoulder_pan", "wrist_roll"] + unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors] + print( + f"Move all joints except {full_turn_motors} sequentially through their " + "entire ranges of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + for motor in full_turn_motors: + range_mins[motor] = 0 + range_maxes[motor] = 4095 + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=drive_modes[motor], + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + logger.info(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + self.bus.disable_torque() + self.bus.configure_motors() + for motor in self.bus.motors: + if motor != "gripper": + # Use 'extended position mode' for all motors except gripper, because in joint mode the servos + # can't rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while + # assembling the arm, you could end up with a servo with a position 0 or 4095 at a crucial + # point + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + # Use 'position control current based' for gripper to be limited by the limit of the current. + # For the follower gripper, it means it can grasp an object without forcing too much even tho, + # its goal position is a complete grasp (both gripper fingers are ordered to join and reach a touch). + # For the leader gripper, it means we can use it as a physical trigger, since we can force with our finger + # to make it move, and it will move back to its original target position when we release the force. + self.bus.write("Operating_Mode", "gripper", OperatingMode.CURRENT_POSITION.value) + # Set gripper's goal pos in current position mode so that we can use it as a trigger. + self.bus.enable_torque("gripper") + if self.is_calibrated: + self.bus.write("Goal_Position", "gripper", self.config.gripper_open_pos) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_action(self) -> dict[str, float]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start = time.perf_counter() + action = self.bus.sync_read("Present_Position") + action = {f"{motor}.pos": val for motor, val in action.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return action + + def send_feedback(self, feedback: dict[str, float]) -> None: + # TODO(rcadene, aliberts): Implement force feedback + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect() + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/teleoperators/so100_leader/__init__.py b/lerobot/common/teleoperators/so100_leader/__init__.py new file mode 100644 index 000000000..63c877e60 --- /dev/null +++ b/lerobot/common/teleoperators/so100_leader/__init__.py @@ -0,0 +1,2 @@ +from .config_so100_leader import SO100LeaderConfig +from .so100_leader import SO100Leader diff --git a/lerobot/common/teleoperators/so100_leader/config_so100_leader.py b/lerobot/common/teleoperators/so100_leader/config_so100_leader.py new file mode 100644 index 000000000..a97949b7e --- /dev/null +++ b/lerobot/common/teleoperators/so100_leader/config_so100_leader.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("so100_leader") +@dataclass +class SO100LeaderConfig(TeleoperatorConfig): + # Port to connect to the arm + port: str diff --git a/lerobot/common/teleoperators/so100_leader/so100_leader.py b/lerobot/common/teleoperators/so100_leader/so100_leader.py new file mode 100644 index 000000000..59b083e3f --- /dev/null +++ b/lerobot/common/teleoperators/so100_leader/so100_leader.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.feetech import ( + FeetechMotorsBus, + OperatingMode, +) + +from ..teleoperator import Teleoperator +from .config_so100_leader import SO100LeaderConfig + +logger = logging.getLogger(__name__) + + +class SO100Leader(Teleoperator): + """ + [SO-100 Leader Arm](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio + """ + + config_class = SO100LeaderConfig + name = "so100_leader" + + def __init__(self, config: SO100LeaderConfig): + super().__init__(config) + self.config = config + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100), + "shoulder_lift": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), + "elbow_flex": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), + "wrist_flex": Motor(4, "sts3215", MotorNormMode.RANGE_M100_100), + "wrist_roll": Motor(5, "sts3215", MotorNormMode.RANGE_M100_100), + "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + + @property + def action_features(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.bus.is_connected + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + full_turn_motor = "wrist_roll" + unknown_range_motors = [motor for motor in self.bus.motors if motor != full_turn_motor] + print( + f"Move all joints except '{full_turn_motor}' sequentially through their " + "entire ranges of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + range_mins[full_turn_motor] = 0 + range_maxes[full_turn_motor] = 4095 + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + self.bus.disable_torque() + self.bus.configure_motors() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_action(self) -> dict[str, float]: + start = time.perf_counter() + action = self.bus.sync_read("Present_Position") + action = {f"{motor}.pos": val for motor, val in action.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return action + + def send_feedback(self, feedback: dict[str, float]) -> None: + # TODO(rcadene, aliberts): Implement force feedback + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect() + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/teleoperators/so101_leader/__init__.py b/lerobot/common/teleoperators/so101_leader/__init__.py new file mode 100644 index 000000000..1f45170e9 --- /dev/null +++ b/lerobot/common/teleoperators/so101_leader/__init__.py @@ -0,0 +1,2 @@ +from .config_so101_leader import SO101LeaderConfig +from .so101_leader import SO101Leader diff --git a/lerobot/common/teleoperators/so101_leader/config_so101_leader.py b/lerobot/common/teleoperators/so101_leader/config_so101_leader.py new file mode 100644 index 000000000..8d91c32df --- /dev/null +++ b/lerobot/common/teleoperators/so101_leader/config_so101_leader.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("so101_leader") +@dataclass +class SO101LeaderConfig(TeleoperatorConfig): + # Port to connect to the arm + port: str + + use_degrees: bool = False diff --git a/lerobot/common/teleoperators/so101_leader/so101_leader.py b/lerobot/common/teleoperators/so101_leader/so101_leader.py new file mode 100644 index 000000000..80ddfbb1d --- /dev/null +++ b/lerobot/common/teleoperators/so101_leader/so101_leader.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.feetech import ( + FeetechMotorsBus, + OperatingMode, +) + +from ..teleoperator import Teleoperator +from .config_so101_leader import SO101LeaderConfig + +logger = logging.getLogger(__name__) + + +class SO101Leader(Teleoperator): + """ + SO-101 Leader Arm designed by TheRobotStudio and Hugging Face. + """ + + config_class = SO101LeaderConfig + name = "so101_leader" + + def __init__(self, config: SO101LeaderConfig): + super().__init__(config) + self.config = config + norm_mode_body = MotorNormMode.DEGREES if config.use_degrees else MotorNormMode.RANGE_M100_100 + self.bus = FeetechMotorsBus( + port=self.config.port, + motors={ + "shoulder_pan": Motor(1, "sts3215", norm_mode_body), + "shoulder_lift": Motor(2, "sts3215", norm_mode_body), + "elbow_flex": Motor(3, "sts3215", norm_mode_body), + "wrist_flex": Motor(4, "sts3215", norm_mode_body), + "wrist_roll": Motor(5, "sts3215", norm_mode_body), + "gripper": Motor(6, "sts3215", MotorNormMode.RANGE_0_100), + }, + calibration=self.calibration, + ) + + @property + def action_features(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.bus.is_connected + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + input(f"Move {self} to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + print( + "Move all joints sequentially through their entire ranges " + "of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion() + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + print(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + self.bus.disable_torque() + self.bus.configure_motors() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value) + + def setup_motors(self) -> None: + for motor in reversed(self.bus.motors): + input(f"Connect the controller board to the '{motor}' motor only and press enter.") + self.bus.setup_motor(motor) + print(f"'{motor}' motor id set to {self.bus.motors[motor].id}") + + def get_action(self) -> dict[str, float]: + start = time.perf_counter() + action = self.bus.sync_read("Present_Position") + action = {f"{motor}.pos": val for motor, val in action.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return action + + def send_feedback(self, feedback: dict[str, float]) -> None: + # TODO(rcadene, aliberts): Implement force feedback + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect() + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/teleoperators/stretch3_gamepad/__init__.py b/lerobot/common/teleoperators/stretch3_gamepad/__init__.py new file mode 100644 index 000000000..ac45b6dd4 --- /dev/null +++ b/lerobot/common/teleoperators/stretch3_gamepad/__init__.py @@ -0,0 +1,2 @@ +from .configuration_stretch3 import Stretch3GamePadConfig +from .stretch3_gamepad import Stretch3GamePad diff --git a/lerobot/common/teleoperators/stretch3_gamepad/configuration_stretch3.py b/lerobot/common/teleoperators/stretch3_gamepad/configuration_stretch3.py new file mode 100644 index 000000000..507a21589 --- /dev/null +++ b/lerobot/common/teleoperators/stretch3_gamepad/configuration_stretch3.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("stretch3") +@dataclass +class Stretch3GamePadConfig(TeleoperatorConfig): + mock: bool = False diff --git a/lerobot/common/teleoperators/stretch3_gamepad/stretch3_gamepad.py b/lerobot/common/teleoperators/stretch3_gamepad/stretch3_gamepad.py new file mode 100644 index 000000000..1e9768c7e --- /dev/null +++ b/lerobot/common/teleoperators/stretch3_gamepad/stretch3_gamepad.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +from stretch_body.gamepad_teleop import GamePadTeleop +from stretch_body.robot_params import RobotParams + +from lerobot.common.errors import DeviceAlreadyConnectedError + +from ..teleoperator import Teleoperator +from .configuration_stretch3 import Stretch3GamePadConfig + +# from stretch_body.gamepad_controller.GamePadController +GAMEPAD_BUTTONS = [ + "middle_led_ring_button_pressed", + "left_stick_x", + "left_stick_y", + "right_stick_x", + "right_stick_y", + "left_stick_button_pressed", + "right_stick_button_pressed", + "bottom_button_pressed", + "top_button_pressed", + "left_button_pressed", + "right_button_pressed", + "left_shoulder_button_pressed", + "right_shoulder_button_pressed", + "select_button_pressed", + "start_button_pressed", + "left_trigger_pulled", + "right_trigger_pulled", + "bottom_pad_pressed", + "top_pad_pressed", + "left_pad_pressed", + "right_pad_pressed", +] + + +class Stretch3GamePad(Teleoperator): + """[Stretch 3](https://hello-robot.com/stretch-3-product), by Hello Robot.""" + + config_class = Stretch3GamePadConfig + name = "stretch3" + + def __init__(self, config: Stretch3GamePadConfig): + raise NotImplementedError + super().__init__(config) + + self.config = config + self.robot_type = self.config.type + + self.api = GamePadTeleop(robot_instance=False) + + self.is_connected = False + self.logs = {} + + # TODO(aliberts): test this + RobotParams.set_logging_level("WARNING") + RobotParams.set_logging_formatter("brief_console_formatter") + + @property + def action_features(self) -> dict: + return { + "dtype": "float32", + "shape": (len(GAMEPAD_BUTTONS),), + "names": {"buttons": GAMEPAD_BUTTONS}, + } + + @property + def feedback_features(self) -> dict: + return {} + + def connect(self) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError( + "ManipulatorRobot is already connected. Do not run `robot.connect()` twice." + ) + + self.api.startup() + self.api._update_state() # Check controller can be read & written + self.api._update_modes() + self.is_connected = True + + def calibrate(self) -> None: + pass + + def get_action(self) -> np.ndarray: + # Read Stretch state + before_read_t = time.perf_counter() + action = self.api.gamepad_controller.get_state() + self.logs["read_pos_dt_s"] = time.perf_counter() - before_read_t + + action = np.asarray(list(action.values())) + + return action + + def send_feedback(self, feedback: np.ndarray) -> None: + pass + + def print_logs(self) -> None: + pass + # TODO(aliberts): move robot-specific logs logic here + + def disconnect(self) -> None: + self.api.stop() + self.is_connected = False diff --git a/lerobot/common/teleoperators/teleoperator.py b/lerobot/common/teleoperators/teleoperator.py new file mode 100644 index 000000000..6a20a3a8a --- /dev/null +++ b/lerobot/common/teleoperators/teleoperator.py @@ -0,0 +1,180 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from pathlib import Path +from typing import Any, Type + +import draccus + +from lerobot.common.constants import HF_LEROBOT_CALIBRATION, TELEOPERATORS +from lerobot.common.motors.motors_bus import MotorCalibration + +from .config import TeleoperatorConfig + + +class Teleoperator(abc.ABC): + """ + The base abstract class for all LeRobot-compatible teleoperation devices. + + This class provides a standardized interface for interacting with physical teleoperators. + Subclasses must implement all abstract methods and properties to be usable. + + Attributes: + config_class (RobotConfig): The expected configuration class for this teleoperator. + name (str): The unique name used to identify this teleoperator type. + """ + + # Set these in ALL subclasses + config_class: Type[TeleoperatorConfig] + name: str + + def __init__(self, config: TeleoperatorConfig): + self.id = config.id + self.calibration_dir = ( + config.calibration_dir + if config.calibration_dir + else HF_LEROBOT_CALIBRATION / TELEOPERATORS / self.name + ) + self.calibration_dir.mkdir(parents=True, exist_ok=True) + self.calibration_fpath = self.calibration_dir / f"{self.id}.json" + self.calibration: dict[str, MotorCalibration] = {} + if self.calibration_fpath.is_file(): + self._load_calibration() + + def __str__(self) -> str: + return f"{self.id} {self.__class__.__name__}" + + @property + @abc.abstractmethod + def action_features(self) -> dict: + """ + A dictionary describing the structure and types of the actions produced by the teleoperator. Its + structure (keys) should match the structure of what is returned by :pymeth:`get_action`. Values for + the dict should be the type of the value if it's a simple value, e.g. `float` for single + proprioceptive value (a joint's goal position/velocity) + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ + pass + + @property + @abc.abstractmethod + def feedback_features(self) -> dict: + """ + A dictionary describing the structure and types of the feedback actions expected by the robot. Its + structure (keys) should match the structure of what is passed to :pymeth:`send_feedback`. Values for + the dict should be the type of the value if it's a simple value, e.g. `float` for single + proprioceptive value (a joint's goal position/velocity) + + Note: this property should be able to be called regardless of whether the robot is connected or not. + """ + pass + + @property + @abc.abstractmethod + def is_connected(self) -> bool: + """ + Whether the teleoperator is currently connected or not. If `False`, calling :pymeth:`get_action` + or :pymeth:`send_feedback` should raise an error. + """ + pass + + @abc.abstractmethod + def connect(self, calibrate: bool = True) -> None: + """ + Establish communication with the teleoperator. + + Args: + calibrate (bool): If True, automatically calibrate the teleoperator after connecting if it's not + calibrated or needs calibration (this is hardware-dependant). + """ + pass + + @property + @abc.abstractmethod + def is_calibrated(self) -> bool: + """Whether the teleoperator is currently calibrated or not. Should be always `True` if not applicable""" + pass + + @abc.abstractmethod + def calibrate(self) -> None: + """ + Calibrate the teleoperator if applicable. If not, this should be a no-op. + + This method should collect any necessary data (e.g., motor offsets) and update the + :pyattr:`calibration` dictionary accordingly. + """ + pass + + def _load_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to load calibration data from the specified file. + + Args: + fpath (Path | None): Optional path to the calibration file. Defaults to `self.calibration_fpath`. + """ + fpath = self.calibration_fpath if fpath is None else fpath + with open(fpath) as f, draccus.config_type("json"): + self.calibration = draccus.load(dict[str, MotorCalibration], f) + + def _save_calibration(self, fpath: Path | None = None) -> None: + """ + Helper to save calibration data to the specified file. + + Args: + fpath (Path | None): Optional path to save the calibration file. Defaults to `self.calibration_fpath`. + """ + fpath = self.calibration_fpath if fpath is None else fpath + with open(fpath, "w") as f, draccus.config_type("json"): + draccus.dump(self.calibration, f, indent=4) + + @abc.abstractmethod + def configure(self) -> None: + """ + Apply any one-time or runtime configuration to the teleoperator. + This may include setting motor parameters, control modes, or initial state. + """ + pass + + @abc.abstractmethod + def get_action(self) -> dict[str, Any]: + """ + Retrieve the current action from the teleoperator. + + Returns: + dict[str, Any]: A flat dictionary representing the teleoperator's current actions. Its + structure should match :pymeth:`observation_features`. + """ + pass + + @abc.abstractmethod + def send_feedback(self, feedback: dict[str, Any]) -> None: + """ + Send a feedback action command to the teleoperator. + + Args: + feedback (dict[str, Any]): Dictionary representing the desired feedback. Its structure should match + :pymeth:`feedback_features`. + + Returns: + dict[str, Any]: The action actually sent to the motors potentially clipped or modified, e.g. by + safety limits on velocity. + """ + pass + + @abc.abstractmethod + def disconnect(self) -> None: + """Disconnect from the teleoperator and perform any necessary cleanup.""" + pass diff --git a/lerobot/common/teleoperators/utils.py b/lerobot/common/teleoperators/utils.py new file mode 100644 index 000000000..b49addc15 --- /dev/null +++ b/lerobot/common/teleoperators/utils.py @@ -0,0 +1,57 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .config import TeleoperatorConfig +from .teleoperator import Teleoperator + + +def make_teleoperator_from_config(config: TeleoperatorConfig) -> Teleoperator: + if config.type == "keyboard": + from .keyboard import KeyboardTeleop + + return KeyboardTeleop(config) + elif config.type == "koch_leader": + from .koch_leader import KochLeader + + return KochLeader(config) + elif config.type == "so100_leader": + from .so100_leader import SO100Leader + + return SO100Leader(config) + elif config.type == "so101_leader": + from .so101_leader import SO101Leader + + return SO101Leader(config) + elif config.type == "stretch3": + from .stretch3_gamepad import Stretch3GamePad + + return Stretch3GamePad(config) + elif config.type == "widowx": + from .widowx import WidowX + + return WidowX(config) + elif config.type == "mock_teleop": + from tests.mocks.mock_teleop import MockTeleop + + return MockTeleop(config) + elif config.type == "gamepad": + from .gamepad.teleop_gamepad import GamepadTeleop + + return GamepadTeleop(config) + elif config.type == "keyboard_ee": + from .keyboard.teleop_keyboard import KeyboardEndEffectorTeleop + + return KeyboardEndEffectorTeleop(config) + else: + raise ValueError(config.type) diff --git a/lerobot/common/teleoperators/widowx/__init__.py b/lerobot/common/teleoperators/widowx/__init__.py new file mode 100644 index 000000000..122ee3290 --- /dev/null +++ b/lerobot/common/teleoperators/widowx/__init__.py @@ -0,0 +1,2 @@ +from .config_widowx import WidowXConfig +from .widowx import WidowX diff --git a/lerobot/common/teleoperators/widowx/config_widowx.py b/lerobot/common/teleoperators/widowx/config_widowx.py new file mode 100644 index 000000000..42fae12db --- /dev/null +++ b/lerobot/common/teleoperators/widowx/config_widowx.py @@ -0,0 +1,25 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from ..config import TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("widowx") +@dataclass +class WidowXConfig(TeleoperatorConfig): + port: str # Port to connect to the arm diff --git a/lerobot/common/teleoperators/widowx/widowx.py b/lerobot/common/teleoperators/widowx/widowx.py new file mode 100644 index 000000000..8a42c9063 --- /dev/null +++ b/lerobot/common/teleoperators/widowx/widowx.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.dynamixel import ( + DriveMode, + DynamixelMotorsBus, + OperatingMode, +) + +from ..teleoperator import Teleoperator +from .config_widowx import WidowXConfig + +logger = logging.getLogger(__name__) + + +class WidowX(Teleoperator): + """ + [WidowX](https://www.trossenrobotics.com/widowx-250) developed by Trossen Robotics + """ + + config_class = WidowXConfig + name = "widowx" + + def __init__(self, config: WidowXConfig): + raise NotImplementedError + super().__init__(config) + self.config = config + self.bus = DynamixelMotorsBus( + port=self.config.port, + motors={ + "waist": Motor(1, "xm430-w350", MotorNormMode.RANGE_M100_100), + "shoulder": Motor(2, "xm430-w350", MotorNormMode.RANGE_M100_100), + "shoulder_shadow": Motor(3, "xm430-w350", MotorNormMode.RANGE_M100_100), + "elbow": Motor(4, "xm430-w350", MotorNormMode.RANGE_M100_100), + "elbow_shadow": Motor(5, "xm430-w350", MotorNormMode.RANGE_M100_100), + "forearm_roll": Motor(6, "xm430-w350", MotorNormMode.RANGE_M100_100), + "wrist_angle": Motor(7, "xm430-w350", MotorNormMode.RANGE_M100_100), + "wrist_rotate": Motor(8, "xl430-w250", MotorNormMode.RANGE_M100_100), + "gripper": Motor(9, "xc430-w150", MotorNormMode.RANGE_0_100), + }, + ) + + @property + def action_features(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.bus.motors} + + @property + def feedback_features(self) -> dict[str, type]: + return {} + + @property + def is_connected(self) -> bool: + return self.bus.is_connected + + def connect(self, calibrate: bool = True): + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self.bus.connect() + if not self.is_calibrated and calibrate: + self.calibrate() + + self.configure() + logger.info(f"{self} connected.") + + @property + def is_calibrated(self) -> bool: + return self.bus.is_calibrated + + def calibrate(self) -> None: + raise NotImplementedError # TODO(aliberts): adapt code below (copied from koch) + logger.info(f"\nRunning calibration of {self}") + self.bus.disable_torque() + for motor in self.bus.motors: + self.bus.write("Operating_Mode", motor, OperatingMode.EXTENDED_POSITION.value) + + self.bus.write("Drive_Mode", "elbow_flex", DriveMode.INVERTED.value) + drive_modes = {motor: 1 if motor == "elbow_flex" else 0 for motor in self.bus.motors} + + input("Move robot to the middle of its range of motion and press ENTER....") + homing_offsets = self.bus.set_half_turn_homings() + + full_turn_motors = ["shoulder_pan", "wrist_roll"] + unknown_range_motors = [motor for motor in self.bus.motors if motor not in full_turn_motors] + print( + f"Move all joints except {full_turn_motors} sequentially through their " + "entire ranges of motion.\nRecording positions. Press ENTER to stop..." + ) + range_mins, range_maxes = self.bus.record_ranges_of_motion(unknown_range_motors) + for motor in full_turn_motors: + range_mins[motor] = 0 + range_maxes[motor] = 4095 + + self.calibration = {} + for motor, m in self.bus.motors.items(): + self.calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=drive_modes[motor], + homing_offset=homing_offsets[motor], + range_min=range_mins[motor], + range_max=range_maxes[motor], + ) + + self.bus.write_calibration(self.calibration) + self._save_calibration() + logger.info(f"Calibration saved to {self.calibration_fpath}") + + def configure(self) -> None: + self.bus.disable_torque() + self.bus.configure_motors() + + # Set secondary/shadow ID for shoulder and elbow. These joints have two motors. + # As a result, if only one of them is required to move to a certain position, + # the other will follow. This is to avoid breaking the motors. + self.bus.write("Secondary_ID", "shoulder_shadow", 2) + self.bus.write("Secondary_ID", "elbow_shadow", 4) + + def get_action(self) -> dict[str, float]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + start = time.perf_counter() + action = self.bus.sync_read("Present_Position") + action = {f"{motor}.pos": val for motor, val in action.items()} + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read action: {dt_ms:.1f}ms") + return action + + def send_feedback(self, feedback: dict[str, float]) -> None: + raise NotImplementedError + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self.bus.disconnect() + logger.info(f"{self} disconnected.") diff --git a/lerobot/common/transport/services.proto b/lerobot/common/transport/services.proto new file mode 100644 index 000000000..29d00005a --- /dev/null +++ b/lerobot/common/transport/services.proto @@ -0,0 +1,59 @@ +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// To generate a classes for transport part (services_pb2.py and services_pb2_grpc.py) use the following command: +// +// python -m grpc_tools.protoc -I . --python_out=. --grpc_python_out=. lerobot/common/transport/services.proto +// +// The command should be launched from the root of the project. + +syntax = "proto3"; + +package transport; + +// LearnerService: the Actor calls this to push transitions. +// The Learner implements this service. +service LearnerService { + // Actor -> Learner to store transitions + rpc StreamParameters(Empty) returns (stream Parameters); + rpc SendTransitions(stream Transition) returns (Empty); + rpc SendInteractions(stream InteractionMessage) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + +enum TransferState { + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; +} + +// Messages +message Transition { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Parameters { + TransferState transfer_state = 1; + bytes data = 2; +} + +message InteractionMessage { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Empty {} diff --git a/lerobot/common/transport/services_pb2.py b/lerobot/common/transport/services_pb2.py new file mode 100644 index 000000000..727beb60d --- /dev/null +++ b/lerobot/common/transport/services_pb2.py @@ -0,0 +1,45 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# NO CHECKED-IN PROTOBUF GENCODE +# source: lerobot/common/transport/services.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'lerobot/common/transport/services.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'lerobot/common/transport/services.proto\x12\ttransport\"L\n\nTransition\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"L\n\nParameters\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"T\n\x12InteractionMessage\x12\x30\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x18.transport.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\x81\x02\n\x0eLearnerService\x12=\n\x10StreamParameters\x12\x10.transport.Empty\x1a\x15.transport.Parameters0\x01\x12<\n\x0fSendTransitions\x12\x15.transport.Transition\x1a\x10.transport.Empty(\x01\x12\x45\n\x10SendInteractions\x12\x1d.transport.InteractionMessage\x1a\x10.transport.Empty(\x01\x12+\n\x05Ready\x12\x10.transport.Empty\x1a\x10.transport.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'lerobot.common.transport.services_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=305 + _globals['_TRANSFERSTATE']._serialized_end=401 + _globals['_TRANSITION']._serialized_start=54 + _globals['_TRANSITION']._serialized_end=130 + _globals['_PARAMETERS']._serialized_start=132 + _globals['_PARAMETERS']._serialized_end=208 + _globals['_INTERACTIONMESSAGE']._serialized_start=210 + _globals['_INTERACTIONMESSAGE']._serialized_end=294 + _globals['_EMPTY']._serialized_start=296 + _globals['_EMPTY']._serialized_end=303 + _globals['_LEARNERSERVICE']._serialized_start=404 + _globals['_LEARNERSERVICE']._serialized_end=661 +# @@protoc_insertion_point(module_scope) diff --git a/lerobot/common/transport/services_pb2_grpc.py b/lerobot/common/transport/services_pb2_grpc.py new file mode 100644 index 000000000..5a7a924fd --- /dev/null +++ b/lerobot/common/transport/services_pb2_grpc.py @@ -0,0 +1,233 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from lerobot.common.transport import services_pb2 as lerobot_dot_common_dot_transport_dot_services__pb2 + +GRPC_GENERATED_VERSION = '1.71.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in lerobot/common/transport/services_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class LearnerServiceStub: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.StreamParameters = channel.unary_stream( + '/transport.LearnerService/StreamParameters', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, + _registered_method=True) + self.SendTransitions = channel.stream_unary( + '/transport.LearnerService/SendTransitions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.SendInteractions = channel.stream_unary( + '/transport.LearnerService/SendInteractions', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/transport.LearnerService/Ready', + request_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + response_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + _registered_method=True) + + +class LearnerServiceServicer: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + def StreamParameters(self, request, context): + """Actor -> Learner to store transitions + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendTransitions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendInteractions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_LearnerServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'StreamParameters': grpc.unary_stream_rpc_method_handler( + servicer.StreamParameters, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.SerializeToString, + ), + 'SendTransitions': grpc.stream_unary_rpc_method_handler( + servicer.SendTransitions, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Transition.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'SendInteractions': grpc.stream_unary_rpc_method_handler( + servicer.SendInteractions, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + response_serializer=lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'transport.LearnerService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('transport.LearnerService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class LearnerService: + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + @staticmethod + def StreamParameters(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/transport.LearnerService/StreamParameters', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Parameters.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendTransitions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.LearnerService/SendTransitions', + lerobot_dot_common_dot_transport_dot_services__pb2.Transition.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendInteractions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/transport.LearnerService/SendInteractions', + lerobot_dot_common_dot_transport_dot_services__pb2.InteractionMessage.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/transport.LearnerService/Ready', + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.SerializeToString, + lerobot_dot_common_dot_transport_dot_services__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/lerobot/common/transport/utils.py b/lerobot/common/transport/utils.py new file mode 100644 index 000000000..774721fc6 --- /dev/null +++ b/lerobot/common/transport/utils.py @@ -0,0 +1,141 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import logging +import pickle # nosec B403: Safe usage for internal serialization only +from multiprocessing import Event, Queue +from typing import Any + +import torch + +from lerobot.common.transport import services_pb2 +from lerobot.common.utils.transition import Transition + +CHUNK_SIZE = 2 * 1024 * 1024 # 2 MB + + +def bytes_buffer_size(buffer: io.BytesIO) -> int: + buffer.seek(0, io.SEEK_END) + result = buffer.tell() + buffer.seek(0) + return result + + +def send_bytes_in_chunks(buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True): + buffer = io.BytesIO(buffer) + size_in_bytes = bytes_buffer_size(buffer) + + sent_bytes = 0 + + logging_method = logging.info if not silent else logging.debug + + logging_method(f"{log_prefix} Buffer size {size_in_bytes / 1024 / 1024} MB with") + + while sent_bytes < size_in_bytes: + transfer_state = services_pb2.TransferState.TRANSFER_MIDDLE + + if sent_bytes + CHUNK_SIZE >= size_in_bytes: + transfer_state = services_pb2.TransferState.TRANSFER_END + elif sent_bytes == 0: + transfer_state = services_pb2.TransferState.TRANSFER_BEGIN + + size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes) + chunk = buffer.read(size_to_read) + + yield message_class(transfer_state=transfer_state, data=chunk) + sent_bytes += size_to_read + logging_method(f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}") + + logging_method(f"{log_prefix} Published {sent_bytes / 1024 / 1024} MB") + + +def receive_bytes_in_chunks(iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""): # type: ignore + bytes_buffer = io.BytesIO() + step = 0 + + logging.info(f"{log_prefix} Starting receiver") + for item in iterator: + logging.debug(f"{log_prefix} Received item") + if shutdown_event.is_set(): + logging.info(f"{log_prefix} Shutting down receiver") + return + + if item.transfer_state == services_pb2.TransferState.TRANSFER_BEGIN: + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + bytes_buffer.write(item.data) + logging.debug(f"{log_prefix} Received data at step 0") + step = 0 + elif item.transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE: + bytes_buffer.write(item.data) + step += 1 + logging.debug(f"{log_prefix} Received data at step {step}") + elif item.transfer_state == services_pb2.TransferState.TRANSFER_END: + bytes_buffer.write(item.data) + logging.debug(f"{log_prefix} Received data at step end size {bytes_buffer_size(bytes_buffer)}") + + queue.put(bytes_buffer.getvalue()) + + bytes_buffer.seek(0) + bytes_buffer.truncate(0) + step = 0 + + logging.debug(f"{log_prefix} Queue updated") + else: + logging.warning(f"{log_prefix} Received unknown transfer state {item.transfer_state}") + raise ValueError(f"Received unknown transfer state {item.transfer_state}") + + +def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: + """Convert model state dict to flat array for transmission""" + buffer = io.BytesIO() + + torch.save(state_dict, buffer) + + return buffer.getvalue() + + +def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return torch.load(buffer, weights_only=True) + + +def python_object_to_bytes(python_object: Any) -> bytes: + return pickle.dumps(python_object) + + +def bytes_to_python_object(buffer: bytes) -> Any: + buffer = io.BytesIO(buffer) + buffer.seek(0) + obj = pickle.load(buffer) # nosec B301: Safe usage of pickle.load + # Add validation checks here + return obj + + +def bytes_to_transitions(buffer: bytes) -> list[Transition]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + transitions = torch.load(buffer, weights_only=True) + return transitions + + +def transitions_to_bytes(transitions: list[Transition]) -> bytes: + buffer = io.BytesIO() + torch.save(transitions, buffer) + return buffer.getvalue() diff --git a/lerobot/common/utils/buffer.py b/lerobot/common/utils/buffer.py new file mode 100644 index 000000000..9ae231ad9 --- /dev/null +++ b/lerobot/common/utils/buffer.py @@ -0,0 +1,841 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from contextlib import suppress +from typing import Callable, Sequence, TypedDict + +import torch +import torch.nn.functional as F # noqa: N812 +from tqdm import tqdm + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.utils.transition import Transition + + +class BatchTransition(TypedDict): + state: dict[str, torch.Tensor] + action: torch.Tensor + reward: torch.Tensor + next_state: dict[str, torch.Tensor] + done: torch.Tensor + truncated: torch.Tensor + complementary_info: dict[str, torch.Tensor | float | int] | None = None + + +def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: + """ + Perform a per-image random crop over a batch of images in a vectorized way. + (Same as shown previously.) + """ + B, C, H, W = images.shape # noqa: N806 + crop_h, crop_w = output_size + + if crop_h > H or crop_w > W: + raise ValueError( + f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})." + ) + + tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device) + lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device) + + rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1) + cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1) + + rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w) + cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w) + + images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) + + # Gather pixels + cropped_hwcn = images_hwcn[torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :] + # cropped_hwcn => (B, crop_h, crop_w, C) + + cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) + return cropped + + +def random_shift(images: torch.Tensor, pad: int = 4): + """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels""" + _, _, h, w = images.shape + images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate") + return random_crop_vectorized(images=images, output_size=(h, w)) + + +class ReplayBuffer: + def __init__( + self, + capacity: int, + device: str = "cuda:0", + state_keys: Sequence[str] | None = None, + image_augmentation_function: Callable | None = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ): + """ + Replay buffer for storing transitions. + It will allocate tensors on the specified device, when the first transition is added. + NOTE: If you encounter memory issues, you can try to use the `optimize_memory` flag to save memory or + and use the `storage_device` flag to store the buffer on a different device. + Args: + capacity (int): Maximum number of transitions to store in the buffer. + device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu"). + state_keys (List[str]): The list of keys that appear in `state` and `next_state`. + image_augmentation_function (Optional[Callable]): A function that takes a batch of images + and returns a batch of augmented images. If None, a default augmentation function is used. + use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. + storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored. + Using "cpu" can help save GPU memory. + optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when + they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1]. + """ + if capacity <= 0: + raise ValueError("Capacity must be greater than 0.") + + self.capacity = capacity + self.device = device + self.storage_device = storage_device + self.position = 0 + self.size = 0 + self.initialized = False + self.optimize_memory = optimize_memory + + # Track episode boundaries for memory optimization + self.episode_ends = torch.zeros(capacity, dtype=torch.bool, device=storage_device) + + # If no state_keys provided, default to an empty list + self.state_keys = state_keys if state_keys is not None else [] + + self.image_augmentation_function = image_augmentation_function + + if image_augmentation_function is None: + base_function = functools.partial(random_shift, pad=4) + self.image_augmentation_function = torch.compile(base_function) + self.use_drq = use_drq + + def _initialize_storage( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + complementary_info: dict[str, torch.Tensor] | None = None, + ): + """Initialize the storage tensors based on the first transition.""" + # Determine shapes from the first transition + state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} + action_shape = action.squeeze(0).shape + + # Pre-allocate tensors for storage + self.states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device) + self.rewards = torch.empty((self.capacity,), device=self.storage_device) + + if not self.optimize_memory: + # Standard approach: store states and next_states separately + self.next_states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + else: + # Memory-optimized approach: don't allocate next_states buffer + # Just create a reference to states for consistent API + self.next_states = self.states # Just a reference for API consistency + + self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + + # Initialize storage for complementary_info + self.has_complementary_info = complementary_info is not None + self.complementary_info_keys = [] + self.complementary_info = {} + + if self.has_complementary_info: + self.complementary_info_keys = list(complementary_info.keys()) + # Pre-allocate tensors for each key in complementary_info + for key, value in complementary_info.items(): + if isinstance(value, torch.Tensor): + value_shape = value.squeeze(0).shape + self.complementary_info[key] = torch.empty( + (self.capacity, *value_shape), device=self.storage_device + ) + elif isinstance(value, (int, float)): + # Handle scalar values similar to reward + self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + else: + raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") + + self.initialized = True + + def __len__(self): + return self.size + + def add( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + reward: float, + next_state: dict[str, torch.Tensor], + done: bool, + truncated: bool, + complementary_info: dict[str, torch.Tensor] | None = None, + ): + """Saves a transition, ensuring tensors are stored on the designated storage device.""" + # Initialize storage if this is the first transition + if not self.initialized: + self._initialize_storage(state=state, action=action, complementary_info=complementary_info) + + # Store the transition in pre-allocated tensors + for key in self.states: + self.states[key][self.position].copy_(state[key].squeeze(dim=0)) + + if not self.optimize_memory: + # Only store next_states if not optimizing memory + self.next_states[key][self.position].copy_(next_state[key].squeeze(dim=0)) + + self.actions[self.position].copy_(action.squeeze(dim=0)) + self.rewards[self.position] = reward + self.dones[self.position] = done + self.truncateds[self.position] = truncated + + # Handle complementary_info if provided and storage is initialized + if complementary_info is not None and self.has_complementary_info: + # Store the complementary_info + for key in self.complementary_info_keys: + if key in complementary_info: + value = complementary_info[key] + if isinstance(value, torch.Tensor): + self.complementary_info[key][self.position].copy_(value.squeeze(dim=0)) + elif isinstance(value, (int, float)): + self.complementary_info[key][self.position] = value + + self.position = (self.position + 1) % self.capacity + self.size = min(self.size + 1, self.capacity) + + def sample(self, batch_size: int) -> BatchTransition: + """Sample a random batch of transitions and collate them into batched tensors.""" + if not self.initialized: + raise RuntimeError("Cannot sample from an empty buffer. Add transitions first.") + + batch_size = min(batch_size, self.size) + high = max(0, self.size - 1) if self.optimize_memory and self.size < self.capacity else self.size + + # Random indices for sampling - create on the same device as storage + idx = torch.randint(low=0, high=high, size=(batch_size,), device=self.storage_device) + + # Identify image keys that need augmentation + image_keys = [k for k in self.states if k.startswith("observation.image")] if self.use_drq else [] + + # Create batched state and next_state + batch_state = {} + batch_next_state = {} + + # First pass: load all state tensors to target device + for key in self.states: + batch_state[key] = self.states[key][idx].to(self.device) + + if not self.optimize_memory: + # Standard approach - load next_states directly + batch_next_state[key] = self.next_states[key][idx].to(self.device) + else: + # Memory-optimized approach - get next_state from the next index + next_idx = (idx + 1) % self.capacity + batch_next_state[key] = self.states[key][next_idx].to(self.device) + + # Apply image augmentation in a batched way if needed + if self.use_drq and image_keys: + # Concatenate all images from state and next_state + all_images = [] + for key in image_keys: + all_images.append(batch_state[key]) + all_images.append(batch_next_state[key]) + + # Optimization: Batch all images and apply augmentation once + all_images_tensor = torch.cat(all_images, dim=0) + augmented_images = self.image_augmentation_function(all_images_tensor) + + # Split the augmented images back to their sources + for i, key in enumerate(image_keys): + # Calculate offsets for the current image key: + # For each key, we have 2*batch_size images (batch_size for states, batch_size for next_states) + # States start at index i*2*batch_size and take up batch_size slots + batch_state[key] = augmented_images[i * 2 * batch_size : (i * 2 + 1) * batch_size] + # Next states start after the states at index (i*2+1)*batch_size and also take up batch_size slots + batch_next_state[key] = augmented_images[(i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size] + + # Sample other tensors + batch_actions = self.actions[idx].to(self.device) + batch_rewards = self.rewards[idx].to(self.device) + batch_dones = self.dones[idx].to(self.device).float() + batch_truncateds = self.truncateds[idx].to(self.device).float() + + # Sample complementary_info if available + batch_complementary_info = None + if self.has_complementary_info: + batch_complementary_info = {} + for key in self.complementary_info_keys: + batch_complementary_info[key] = self.complementary_info[key][idx].to(self.device) + + return BatchTransition( + state=batch_state, + action=batch_actions, + reward=batch_rewards, + next_state=batch_next_state, + done=batch_dones, + truncated=batch_truncateds, + complementary_info=batch_complementary_info, + ) + + def get_iterator( + self, + batch_size: int, + async_prefetch: bool = True, + queue_size: int = 2, + ): + """ + Creates an infinite iterator that yields batches of transitions. + Will automatically restart when internal iterator is exhausted. + + Args: + batch_size (int): Size of batches to sample + async_prefetch (bool): Whether to use asynchronous prefetching with threads (default: True) + queue_size (int): Number of batches to prefetch (default: 2) + + Yields: + BatchTransition: Batched transitions + """ + while True: # Create an infinite loop + if async_prefetch: + # Get the standard iterator + iterator = self._get_async_iterator(queue_size=queue_size, batch_size=batch_size) + else: + iterator = self._get_naive_iterator(batch_size=batch_size, queue_size=queue_size) + + # Yield all items from the iterator + with suppress(StopIteration): + yield from iterator + + def _get_async_iterator(self, batch_size: int, queue_size: int = 2): + """ + Create an iterator that continuously yields prefetched batches in a + background thread. The design is intentionally simple and avoids busy + waiting / complex state management. + + Args: + batch_size (int): Size of batches to sample. + queue_size (int): Maximum number of prefetched batches to keep in + memory. + + Yields: + BatchTransition: A batch sampled from the replay buffer. + """ + import queue + import threading + + data_queue: queue.Queue = queue.Queue(maxsize=queue_size) + shutdown_event = threading.Event() + + def producer() -> None: + """Continuously put sampled batches into the queue until shutdown.""" + while not shutdown_event.is_set(): + try: + batch = self.sample(batch_size) + # The timeout ensures the thread unblocks if the queue is full + # and the shutdown event gets set meanwhile. + data_queue.put(batch, block=True, timeout=0.5) + except queue.Full: + # Queue is full – loop again (will re-check shutdown_event) + continue + except Exception: + # Surface any unexpected error and terminate the producer. + shutdown_event.set() + + producer_thread = threading.Thread(target=producer, daemon=True) + producer_thread.start() + + try: + while not shutdown_event.is_set(): + try: + yield data_queue.get(block=True) + except Exception: + # If the producer already set the shutdown flag we exit. + if shutdown_event.is_set(): + break + finally: + shutdown_event.set() + # Drain the queue quickly to help the thread exit if it's blocked on `put`. + while not data_queue.empty(): + _ = data_queue.get_nowait() + # Give the producer thread a bit of time to finish. + producer_thread.join(timeout=1.0) + + def _get_naive_iterator(self, batch_size: int, queue_size: int = 2): + """ + Creates a simple non-threaded iterator that yields batches. + + Args: + batch_size (int): Size of batches to sample + queue_size (int): Number of initial batches to prefetch + + Yields: + BatchTransition: Batch transitions + """ + import collections + + queue = collections.deque() + + def enqueue(n): + for _ in range(n): + data = self.sample(batch_size) + queue.append(data) + + enqueue(queue_size) + while queue: + yield queue.popleft() + enqueue(1) + + @classmethod + def from_lerobot_dataset( + cls, + lerobot_dataset: LeRobotDataset, + device: str = "cuda:0", + state_keys: Sequence[str] | None = None, + capacity: int | None = None, + image_augmentation_function: Callable | None = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ) -> "ReplayBuffer": + """ + Convert a LeRobotDataset into a ReplayBuffer. + + Args: + lerobot_dataset (LeRobotDataset): The dataset to convert. + device (str): The device for sampling tensors. Defaults to "cuda:0". + state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`. + capacity (int | None): Buffer capacity. If None, uses dataset length. + action_mask (Sequence[int] | None): Indices of action dimensions to keep. + image_augmentation_function (Callable | None): Function for image augmentation. + If None, uses default random shift with pad=4. + use_drq (bool): Whether to use DrQ image augmentation when sampling. + storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory. + optimize_memory (bool): If True, reduces memory usage by not duplicating state data. + + Returns: + ReplayBuffer: The replay buffer with dataset transitions. + """ + if capacity is None: + capacity = len(lerobot_dataset) + + if capacity < len(lerobot_dataset): + raise ValueError( + "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset." + ) + + # Create replay buffer with image augmentation and DrQ settings + replay_buffer = cls( + capacity=capacity, + device=device, + state_keys=state_keys, + image_augmentation_function=image_augmentation_function, + use_drq=use_drq, + storage_device=storage_device, + optimize_memory=optimize_memory, + ) + + # Convert dataset to transitions + list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) + + # Initialize the buffer with the first transition to set up storage tensors + if list_transition: + first_transition = list_transition[0] + first_state = {k: v.to(device) for k, v in first_transition["state"].items()} + first_action = first_transition["action"].to(device) + + # Get complementary info if available + first_complementary_info = None + if ( + "complementary_info" in first_transition + and first_transition["complementary_info"] is not None + ): + first_complementary_info = { + k: v.to(device) for k, v in first_transition["complementary_info"].items() + } + + replay_buffer._initialize_storage( + state=first_state, action=first_action, complementary_info=first_complementary_info + ) + + # Fill the buffer with all transitions + for data in list_transition: + for k, v in data.items(): + if isinstance(v, dict): + for key, tensor in v.items(): + v[key] = tensor.to(storage_device) + elif isinstance(v, torch.Tensor): + data[k] = v.to(storage_device) + + action = data["action"] + + replay_buffer.add( + state=data["state"], + action=action, + reward=data["reward"], + next_state=data["next_state"], + done=data["done"], + truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset + complementary_info=data.get("complementary_info", None), + ) + + return replay_buffer + + def to_lerobot_dataset( + self, + repo_id: str, + fps=1, + root=None, + task_name="from_replay_buffer", + ) -> LeRobotDataset: + """ + Converts all transitions in this ReplayBuffer into a single LeRobotDataset object. + """ + if self.size == 0: + raise ValueError("The replay buffer is empty. Cannot convert to a dataset.") + + # Create features dictionary for the dataset + features = { + "index": {"dtype": "int64", "shape": [1]}, # global index across episodes + "episode_index": {"dtype": "int64", "shape": [1]}, # which episode + "frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode + "timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy + "task_index": {"dtype": "int64", "shape": [1]}, + } + + # Add "action" + sample_action = self.actions[0] + act_info = guess_feature_info(t=sample_action, name="action") + features["action"] = act_info + + # Add "reward" and "done" + features["next.reward"] = {"dtype": "float32", "shape": (1,)} + features["next.done"] = {"dtype": "bool", "shape": (1,)} + + # Add state keys + for key in self.states: + sample_val = self.states[key][0] + f_info = guess_feature_info(t=sample_val, name=key) + features[key] = f_info + + # Add complementary_info keys if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + sample_val = self.complementary_info[key][0] + if isinstance(sample_val, torch.Tensor) and sample_val.ndim == 0: + sample_val = sample_val.unsqueeze(0) + f_info = guess_feature_info(t=sample_val, name=f"complementary_info.{key}") + features[f"complementary_info.{key}"] = f_info + + # Create an empty LeRobotDataset + lerobot_dataset = LeRobotDataset.create( + repo_id=repo_id, + fps=fps, + root=root, + robot_type=None, + features=features, + use_videos=True, + ) + + # Start writing images if needed + lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) + + # Convert transitions into episodes and frames + episode_index = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index) + + frame_idx_in_episode = 0 + for idx in range(self.size): + actual_idx = (self.position - self.size + idx) % self.capacity + + frame_dict = {} + + # Fill the data for state keys + for key in self.states: + frame_dict[key] = self.states[key][actual_idx].cpu() + + # Fill action, reward, done + frame_dict["action"] = self.actions[actual_idx].cpu() + frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() + frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + + # Add complementary_info if available + if self.has_complementary_info: + for key in self.complementary_info_keys: + val = self.complementary_info[key][actual_idx] + # Convert tensors to CPU + if isinstance(val, torch.Tensor): + if val.ndim == 0: + val = val.unsqueeze(0) + frame_dict[f"complementary_info.{key}"] = val.cpu() + # Non-tensor values can be used directly + else: + frame_dict[f"complementary_info.{key}"] = val + + # Add to the dataset's buffer + lerobot_dataset.add_frame(frame_dict, task=task_name) + + # Move to next frame + frame_idx_in_episode += 1 + + # If we reached an episode boundary, call save_episode, reset counters + if self.dones[actual_idx] or self.truncateds[actual_idx]: + lerobot_dataset.save_episode() + episode_index += 1 + frame_idx_in_episode = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( + episode_index=episode_index + ) + + # Save any remaining frames in the buffer + if lerobot_dataset.episode_buffer["size"] > 0: + lerobot_dataset.save_episode() + + lerobot_dataset.stop_image_writer() + + return lerobot_dataset + + @staticmethod + def _lerobotdataset_to_transitions( + dataset: LeRobotDataset, + state_keys: Sequence[str] | None = None, + ) -> list[Transition]: + """ + Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions. + + Args: + dataset (LeRobotDataset): + The dataset to convert. Each item in the dataset is expected to have + at least the following keys: + { + "action": ... + "next.reward": ... + "next.done": ... + "episode_index": ... + } + plus whatever your 'state_keys' specify. + + state_keys (Sequence[str] | None): + The dataset keys to include in 'state' and 'next_state'. Their names + will be kept as-is in the output transitions. E.g. + ["observation.state", "observation.environment_state"]. + If None, you must handle or define default keys. + + Returns: + transitions (List[Transition]): + A list of Transition dictionaries with the same length as `dataset`. + """ + if state_keys is None: + raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.") + + transitions = [] + num_frames = len(dataset) + + # Check if the dataset has "next.done" key + sample = dataset[0] + has_done_key = "next.done" in sample + + # Check for complementary_info keys + complementary_info_keys = [key for key in sample if key.startswith("complementary_info.")] + has_complementary_info = len(complementary_info_keys) > 0 + + # If not, we need to infer it from episode boundaries + if not has_done_key: + print("'next.done' key not found in dataset. Inferring from episode boundaries...") + + for i in tqdm(range(num_frames)): + current_sample = dataset[i] + + # ----- 1) Current state ----- + current_state: dict[str, torch.Tensor] = {} + for key in state_keys: + val = current_sample[key] + current_state[key] = val.unsqueeze(0) # Add batch dimension + + # ----- 2) Action ----- + action = current_sample["action"].unsqueeze(0) # Add batch dimension + + # ----- 3) Reward and done ----- + reward = float(current_sample["next.reward"].item()) # ensure float + + # Determine done flag - use next.done if available, otherwise infer from episode boundaries + if has_done_key: + done = bool(current_sample["next.done"].item()) # ensure bool + else: + # If this is the last frame or if next frame is in a different episode, mark as done + done = False + if i == num_frames - 1: + done = True + elif i < num_frames - 1: + next_sample = dataset[i + 1] + if next_sample["episode_index"] != current_sample["episode_index"]: + done = True + + # TODO: (azouitine) Handle truncation (using the same value as done for now) + truncated = done + + # ----- 4) Next state ----- + # If not done and the next sample is in the same episode, we pull the next sample's state. + # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state. + next_state = current_state # default + if not done and (i < num_frames - 1): + next_sample = dataset[i + 1] + if next_sample["episode_index"] == current_sample["episode_index"]: + # Build next_state from the same keys + next_state_data: dict[str, torch.Tensor] = {} + for key in state_keys: + val = next_sample[key] + next_state_data[key] = val.unsqueeze(0) # Add batch dimension + next_state = next_state_data + + # ----- 5) Complementary info (if available) ----- + complementary_info = None + if has_complementary_info: + complementary_info = {} + for key in complementary_info_keys: + # Strip the "complementary_info." prefix to get the actual key + clean_key = key[len("complementary_info.") :] + val = current_sample[key] + # Handle tensor and non-tensor values differently + if isinstance(val, torch.Tensor): + complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension + else: + # TODO: (azouitine) Check if it's necessary to convert to tensor + # For non-tensor values, use directly + complementary_info[clean_key] = val + + # ----- Construct the Transition ----- + transition = Transition( + state=current_state, + action=action, + reward=reward, + next_state=next_state, + done=done, + truncated=truncated, + complementary_info=complementary_info, + ) + transitions.append(transition) + + return transitions + + +# Utility function to guess shapes/dtypes from a tensor +def guess_feature_info(t, name: str): + """ + Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. + If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. + Otherwise default to appropriate dtype for numeric. + """ + + shape = tuple(t.shape) + # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image' + if len(shape) == 3 and shape[0] in [1, 3]: + return { + "dtype": "image", + "shape": shape, + } + else: + # Otherwise treat as numeric + return { + "dtype": "float32", + "shape": shape, + } + + +def concatenate_batch_transitions( + left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition +) -> BatchTransition: + """ + Concatenates two BatchTransition objects into one. + + This function merges the right BatchTransition into the left one by concatenating + all corresponding tensors along dimension 0. The operation modifies the left_batch_transitions + in place and also returns it. + + Args: + left_batch_transitions (BatchTransition): The first batch to concatenate and the one + that will be modified in place. + right_batch_transition (BatchTransition): The second batch to append to the first one. + + Returns: + BatchTransition: The concatenated batch (same object as left_batch_transitions). + + Warning: + This function modifies the left_batch_transitions object in place. + """ + # Concatenate state fields + left_batch_transitions["state"] = { + key: torch.cat( + [left_batch_transitions["state"][key], right_batch_transition["state"][key]], + dim=0, + ) + for key in left_batch_transitions["state"] + } + + # Concatenate basic fields + left_batch_transitions["action"] = torch.cat( + [left_batch_transitions["action"], right_batch_transition["action"]], dim=0 + ) + left_batch_transitions["reward"] = torch.cat( + [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0 + ) + + # Concatenate next_state fields + left_batch_transitions["next_state"] = { + key: torch.cat( + [left_batch_transitions["next_state"][key], right_batch_transition["next_state"][key]], + dim=0, + ) + for key in left_batch_transitions["next_state"] + } + + # Concatenate done and truncated fields + left_batch_transitions["done"] = torch.cat( + [left_batch_transitions["done"], right_batch_transition["done"]], dim=0 + ) + left_batch_transitions["truncated"] = torch.cat( + [left_batch_transitions["truncated"], right_batch_transition["truncated"]], + dim=0, + ) + + # Handle complementary_info + left_info = left_batch_transitions.get("complementary_info") + right_info = right_batch_transition.get("complementary_info") + + # Only process if right_info exists + if right_info is not None: + # Initialize left complementary_info if needed + if left_info is None: + left_batch_transitions["complementary_info"] = right_info + else: + # Concatenate each field + for key in right_info: + if key in left_info: + left_info[key] = torch.cat([left_info[key], right_info[key]], dim=0) + else: + left_info[key] = right_info[key] + + return left_batch_transitions diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py new file mode 100644 index 000000000..e69de29bb diff --git a/lerobot/common/utils/encoding_utils.py b/lerobot/common/utils/encoding_utils.py new file mode 100644 index 000000000..195cdbe2c --- /dev/null +++ b/lerobot/common/utils/encoding_utils.py @@ -0,0 +1,67 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def encode_sign_magnitude(value: int, sign_bit_index: int): + """ + https://en.wikipedia.org/wiki/Signed_number_representations#Sign%E2%80%93magnitude + """ + max_magnitude = (1 << sign_bit_index) - 1 + magnitude = abs(value) + if magnitude > max_magnitude: + raise ValueError(f"Magnitude {magnitude} exceeds {max_magnitude} (max for {sign_bit_index=})") + + direction_bit = 1 if value < 0 else 0 + return (direction_bit << sign_bit_index) | magnitude + + +def decode_sign_magnitude(encoded_value: int, sign_bit_index: int): + """ + https://en.wikipedia.org/wiki/Signed_number_representations#Sign%E2%80%93magnitude + """ + direction_bit = (encoded_value >> sign_bit_index) & 1 + magnitude_mask = (1 << sign_bit_index) - 1 + magnitude = encoded_value & magnitude_mask + return -magnitude if direction_bit else magnitude + + +def encode_twos_complement(value: int, n_bytes: int): + """ + https://en.wikipedia.org/wiki/Signed_number_representations#Two%27s_complement + """ + + bit_width = n_bytes * 8 + min_val = -(1 << (bit_width - 1)) + max_val = (1 << (bit_width - 1)) - 1 + + if not (min_val <= value <= max_val): + raise ValueError( + f"Value {value} out of range for {n_bytes}-byte two's complement: [{min_val}, {max_val}]" + ) + + if value >= 0: + return value + + return (1 << bit_width) + value + + +def decode_twos_complement(value: int, n_bytes: int) -> int: + """ + https://en.wikipedia.org/wiki/Signed_number_representations#Two%27s_complement + """ + bits = n_bytes * 8 + sign_bit = 1 << (bits - 1) + if value & sign_bit: + value -= 1 << bits + return value diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index cd5f82450..5c29b5a84 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -28,6 +28,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b try: # Primary method to get the package version package_version = importlib.metadata.version(pkg_name) + except importlib.metadata.PackageNotFoundError: # Fallback method: Only for "torch" and versions containing "dev" if pkg_name == "torch": @@ -43,6 +44,9 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b except ImportError: # If the package can't be imported, it's not available package_exists = False + elif pkg_name == "grpc": + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False diff --git a/lerobot/common/utils/process.py b/lerobot/common/utils/process.py new file mode 100644 index 000000000..72438b6f9 --- /dev/null +++ b/lerobot/common/utils/process.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import signal +import sys + + +class ProcessSignalHandler: + """Utility class to attach graceful shutdown signal handlers. + + The class exposes a shutdown_event attribute that is set when a shutdown + signal is received. A counter tracks how many shutdown signals have been + caught. On the second signal the process exits with status 1. + """ + + _SUPPORTED_SIGNALS = ("SIGINT", "SIGTERM", "SIGHUP", "SIGQUIT") + + def __init__(self, use_threads: bool, display_pid: bool = False): + # TODO: Check if we can use Event from threading since Event from + # multiprocessing is the a clone of threading.Event. + # https://docs.python.org/3/library/multiprocessing.html#multiprocessing.Event + if use_threads: + from threading import Event + else: + from multiprocessing import Event + + self.shutdown_event = Event() + self._counter: int = 0 + self._display_pid = display_pid + + self._register_handlers() + + @property + def counter(self) -> int: # pragma: no cover – simple accessor + """Number of shutdown signals that have been intercepted.""" + return self._counter + + def _register_handlers(self): + """Attach the internal _signal_handler to a subset of POSIX signals.""" + + def _signal_handler(signum, frame): + pid_str = "" + if self._display_pid: + pid_str = f"[PID: {os.getpid()}]" + logging.info(f"{pid_str} Shutdown signal {signum} received. Cleaning up…") + self.shutdown_event.set() + self._counter += 1 + + # On a second Ctrl-C (or any supported signal) force the exit to + # mimic the previous behaviour while giving the caller one chance to + # shutdown gracefully. + # TODO: Investigate if we need it later + if self._counter > 1: + logging.info("Force shutdown") + sys.exit(1) + + for sig_name in self._SUPPORTED_SIGNALS: + sig = getattr(signal, sig_name, None) + if sig is None: + # The signal is not available on this platform (Windows for + # instance does not provide SIGHUP, SIGQUIT…). Skip it. + continue + try: + signal.signal(sig, _signal_handler) + except (ValueError, OSError): # pragma: no cover – unlikely but safe + # Signal not supported or we are in a non-main thread. + continue diff --git a/lerobot/common/utils/queue.py b/lerobot/common/utils/queue.py new file mode 100644 index 000000000..ceb30e2bf --- /dev/null +++ b/lerobot/common/utils/queue.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from queue import Empty +from typing import Any + +from torch.multiprocessing import Queue + + +def get_last_item_from_queue(queue: Queue, block=True, timeout: float = 0.1) -> Any: + if block: + try: + item = queue.get(timeout=timeout) + except Empty: + return None + else: + item = None + + # Drain queue and keep only the most recent parameters + try: + while True: + item = queue.get_nowait() + except Empty: + pass + + return item diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/utils/robot_utils.py similarity index 70% rename from lerobot/common/robot_devices/utils.py rename to lerobot/common/utils/robot_utils.py index 837c9d2eb..e6c0cfe6d 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/utils/robot_utils.py @@ -42,24 +42,3 @@ def safe_disconnect(func): raise e return wrapper - - -class RobotDeviceNotConnectedError(Exception): - """Exception raised when the robot device is not connected.""" - - def __init__( - self, message="This robot device is not connected. Try calling `robot_device.connect()` first." - ): - self.message = message - super().__init__(self.message) - - -class RobotDeviceAlreadyConnectedError(Exception): - """Exception raised when the robot device is already connected.""" - - def __init__( - self, - message="This robot device is already connected. Try not calling `robot_device.connect()` twice.", - ): - self.message = message - super().__init__(self.message) diff --git a/lerobot/common/utils/transition.py b/lerobot/common/utils/transition.py new file mode 100644 index 000000000..db413c388 --- /dev/null +++ b/lerobot/common/utils/transition.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TypedDict + +import torch + + +class Transition(TypedDict): + state: dict[str, torch.Tensor] + action: torch.Tensor + reward: float + next_state: dict[str, torch.Tensor] + done: bool + truncated: bool + complementary_info: dict[str, torch.Tensor | float | int] | None = None + + +def move_transition_to_device(transition: Transition, device: str = "cpu") -> Transition: + device = torch.device(device) + non_blocking = device.type == "cuda" + + # Move state tensors to device + transition["state"] = { + key: val.to(device, non_blocking=non_blocking) for key, val in transition["state"].items() + } + + # Move action to device + transition["action"] = transition["action"].to(device, non_blocking=non_blocking) + + # Move reward and done if they are tensors + if isinstance(transition["reward"], torch.Tensor): + transition["reward"] = transition["reward"].to(device, non_blocking=non_blocking) + + if isinstance(transition["done"], torch.Tensor): + transition["done"] = transition["done"].to(device, non_blocking=non_blocking) + + if isinstance(transition["truncated"], torch.Tensor): + transition["truncated"] = transition["truncated"].to(device, non_blocking=non_blocking) + + # Move next_state tensors to device + transition["next_state"] = { + key: val.to(device, non_blocking=non_blocking) for key, val in transition["next_state"].items() + } + + # Move complementary_info tensors if present + if transition.get("complementary_info") is not None: + for key, val in transition["complementary_info"].items(): + if isinstance(val, torch.Tensor): + transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking) + elif isinstance(val, (int, float, bool)): + transition["complementary_info"][key] = torch.tensor(val, device=device) + else: + raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]") + return transition + + +def move_state_dict_to_device(state_dict, device="cpu"): + """ + Recursively move all tensors in a (potentially) nested + dict/list/tuple structure to the CPU. + """ + if isinstance(state_dict, torch.Tensor): + return state_dict.to(device) + elif isinstance(state_dict, dict): + return {k: move_state_dict_to_device(v, device=device) for k, v in state_dict.items()} + elif isinstance(state_dict, list): + return [move_state_dict_to_device(v, device=device) for v in state_dict] + elif isinstance(state_dict, tuple): + return tuple(move_state_dict_to_device(v, device=device) for v in state_dict) + else: + return state_dict diff --git a/lerobot/common/utils/visualization_utils.py b/lerobot/common/utils/visualization_utils.py new file mode 100644 index 000000000..dfffece5f --- /dev/null +++ b/lerobot/common/utils/visualization_utils.py @@ -0,0 +1,26 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import rerun as rr + + +def _init_rerun(session_name: str = "lerobot_control_loop") -> None: + """Initializes the Rerun SDK for visualizing the control loop.""" + batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") + os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size + rr.init(session_name) + memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%") + rr.spawn(memory_limit=memory_limit) diff --git a/lerobot/common/utils/wandb_utils.py b/lerobot/common/utils/wandb_utils.py index 3fe241d41..ac4d22343 100644 --- a/lerobot/common/utils/wandb_utils.py +++ b/lerobot/common/utils/wandb_utils.py @@ -30,9 +30,10 @@ def cfg_to_group(cfg: TrainPipelineConfig, return_list: bool = False) -> list[st """Return a group name for logging. Optionally returns group name as list.""" lst = [ f"policy:{cfg.policy.type}", - f"dataset:{cfg.dataset.repo_id}", f"seed:{cfg.seed}", ] + if cfg.dataset is not None: + lst.append(f"dataset:{cfg.dataset.repo_id}") if cfg.env is not None: lst.append(f"env:{cfg.env.type}") return lst if return_list else "-".join(lst) @@ -92,6 +93,12 @@ class WandBLogger: resume="must" if cfg.resume else None, mode=self.cfg.mode if self.cfg.mode in ["online", "offline", "disabled"] else "online", ) + run_id = wandb.run.id + # NOTE: We will override the cfg.wandb.run_id with the wandb run id. + # This is because we want to be able to resume the run from the wandb run id. + cfg.wandb.run_id = run_id + # Handle custom step key for rl asynchronous training. + self._wandb_custom_step_key: set[str] | None = None print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"])) logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}") self._wandb = wandb @@ -108,17 +115,45 @@ class WandBLogger: artifact.add_file(checkpoint_dir / PRETRAINED_MODEL_DIR / SAFETENSORS_SINGLE_FILE) self._wandb.log_artifact(artifact) - def log_dict(self, d: dict, step: int, mode: str = "train"): + def log_dict( + self, d: dict, step: int | None = None, mode: str = "train", custom_step_key: str | None = None + ): if mode not in {"train", "eval"}: raise ValueError(mode) + if step is None and custom_step_key is None: + raise ValueError("Either step or custom_step_key must be provided.") + + # NOTE: This is not simple. Wandb step must always monotonically increase and it + # increases with each wandb.log call, but in the case of asynchronous RL for example, + # multiple time steps is possible. For example, the interaction step with the environment, + # the training step, the evaluation step, etc. So we need to define a custom step key + # to log the correct step for each metric. + if custom_step_key is not None: + if self._wandb_custom_step_key is None: + self._wandb_custom_step_key = set() + new_custom_key = f"{mode}/{custom_step_key}" + if new_custom_key not in self._wandb_custom_step_key: + self._wandb_custom_step_key.add(new_custom_key) + self._wandb.define_metric(new_custom_key, hidden=True) for k, v in d.items(): if not isinstance(v, (int, float, str)): logging.warning( - f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' + f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' ) continue - self._wandb.log({f"{mode}/{k}": v}, step=step) + + # Do not log the custom step key itself. + if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key: + continue + + if custom_step_key is not None: + value_custom_step = d[custom_step_key] + data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step} + self._wandb.log(data) + continue + + self._wandb.log(data={f"{mode}/{k}": v}, step=step) def log_video(self, video_path: str, step: int, mode: str = "train"): if mode not in {"train", "eval"}: diff --git a/lerobot/configs/train.py b/lerobot/configs/train.py index 7a787b83e..377fb8a9b 100644 --- a/lerobot/configs/train.py +++ b/lerobot/configs/train.py @@ -116,6 +116,11 @@ class TrainPipelineConfig(HubMixin): self.optimizer = self.policy.get_optimizer_preset() self.scheduler = self.policy.get_scheduler_preset() + if self.policy.push_to_hub and not self.policy.repo_id: + raise ValueError( + "'policy.repo_id' argument missing. Please specify it to push the model to the hub." + ) + @classmethod def __get_path_fields__(cls) -> list[str]: """This enables the parser to load config from the policy using `--policy.path=local/dir`""" @@ -170,6 +175,10 @@ class TrainPipelineConfig(HubMixin): ) from e cli_args = kwargs.pop("cli_args", []) - cfg = draccus.parse(cls, config_file, args=cli_args) + with draccus.config_type("json"): + return draccus.parse(cls, config_file, args=cli_args) - return cfg + +@dataclass(kw_only=True) +class TrainRLServerPipelineConfig(TrainPipelineConfig): + dataset: DatasetConfig | None = None # NOTE: In RL, we don't need an offline dataset diff --git a/lerobot/configs/types.py b/lerobot/configs/types.py index 6b3d92e80..6040ff70b 100644 --- a/lerobot/configs/types.py +++ b/lerobot/configs/types.py @@ -23,6 +23,7 @@ class FeatureType(str, Enum): VISUAL = "VISUAL" ENV = "ENV" ACTION = "ACTION" + REWARD = "REWARD" class NormalizationMode(str, Enum): diff --git a/lerobot/find_cameras.py b/lerobot/find_cameras.py new file mode 100644 index 000000000..34f4865b1 --- /dev/null +++ b/lerobot/find_cameras.py @@ -0,0 +1,315 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper to find the camera devices available in your system. + +Example: + +```shell +python -m lerobot.find_cameras +``` +""" + +# NOTE(Steven): RealSense can also be identified/opened as OpenCV cameras. If you know the camera is a RealSense, use the `lerobot.find_cameras realsense` flag to avoid confusion. +# NOTE(Steven): macOS cameras sometimes report different FPS at init time, not an issue here as we don't specify FPS when opening the cameras, but the information displayed might not be truthful. + +import argparse +import concurrent.futures +import logging +import time +from pathlib import Path +from typing import Any, Dict, List + +import numpy as np +from PIL import Image + +from lerobot.common.cameras.configs import ColorMode +from lerobot.common.cameras.opencv.camera_opencv import OpenCVCamera +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.common.cameras.realsense.camera_realsense import RealSenseCamera +from lerobot.common.cameras.realsense.configuration_realsense import RealSenseCameraConfig + +logger = logging.getLogger(__name__) + + +def find_all_opencv_cameras() -> List[Dict[str, Any]]: + """ + Finds all available OpenCV cameras plugged into the system. + + Returns: + A list of all available OpenCV cameras with their metadata. + """ + all_opencv_cameras_info: List[Dict[str, Any]] = [] + logger.info("Searching for OpenCV cameras...") + try: + opencv_cameras = OpenCVCamera.find_cameras() + for cam_info in opencv_cameras: + all_opencv_cameras_info.append(cam_info) + logger.info(f"Found {len(opencv_cameras)} OpenCV cameras.") + except Exception as e: + logger.error(f"Error finding OpenCV cameras: {e}") + + return all_opencv_cameras_info + + +def find_all_realsense_cameras() -> List[Dict[str, Any]]: + """ + Finds all available RealSense cameras plugged into the system. + + Returns: + A list of all available RealSense cameras with their metadata. + """ + all_realsense_cameras_info: List[Dict[str, Any]] = [] + logger.info("Searching for RealSense cameras...") + try: + realsense_cameras = RealSenseCamera.find_cameras() + for cam_info in realsense_cameras: + all_realsense_cameras_info.append(cam_info) + logger.info(f"Found {len(realsense_cameras)} RealSense cameras.") + except ImportError: + logger.warning("Skipping RealSense camera search: pyrealsense2 library not found or not importable.") + except Exception as e: + logger.error(f"Error finding RealSense cameras: {e}") + + return all_realsense_cameras_info + + +def find_and_print_cameras(camera_type_filter: str | None = None) -> List[Dict[str, Any]]: + """ + Finds available cameras based on an optional filter and prints their information. + + Args: + camera_type_filter: Optional string to filter cameras ("realsense" or "opencv"). + If None, lists all cameras. + + Returns: + A list of all available cameras matching the filter, with their metadata. + """ + all_cameras_info: List[Dict[str, Any]] = [] + + if camera_type_filter: + camera_type_filter = camera_type_filter.lower() + + if camera_type_filter is None or camera_type_filter == "opencv": + all_cameras_info.extend(find_all_opencv_cameras()) + if camera_type_filter is None or camera_type_filter == "realsense": + all_cameras_info.extend(find_all_realsense_cameras()) + + if not all_cameras_info: + if camera_type_filter: + logger.warning(f"No {camera_type_filter} cameras were detected.") + else: + logger.warning("No cameras (OpenCV or RealSense) were detected.") + else: + print("\n--- Detected Cameras ---") + for i, cam_info in enumerate(all_cameras_info): + print(f"Camera #{i}:") + for key, value in cam_info.items(): + if key == "default_stream_profile" and isinstance(value, dict): + print(f" {key.replace('_', ' ').capitalize()}:") + for sub_key, sub_value in value.items(): + print(f" {sub_key.capitalize()}: {sub_value}") + else: + print(f" {key.replace('_', ' ').capitalize()}: {value}") + print("-" * 20) + return all_cameras_info + + +def save_image( + img_array: np.ndarray, + camera_identifier: str | int, + images_dir: Path, + camera_type: str, +): + """ + Saves a single image to disk using Pillow. Handles color conversion if necessary. + """ + try: + img = Image.fromarray(img_array, mode="RGB") + + safe_identifier = str(camera_identifier).replace("/", "_").replace("\\", "_") + filename_prefix = f"{camera_type.lower()}_{safe_identifier}" + filename = f"{filename_prefix}.png" + + path = images_dir / filename + path.parent.mkdir(parents=True, exist_ok=True) + img.save(str(path)) + logger.info(f"Saved image: {path}") + except Exception as e: + logger.error(f"Failed to save image for camera {camera_identifier} (type {camera_type}): {e}") + + +def create_camera_instance(cam_meta: Dict[str, Any]) -> Dict[str, Any] | None: + """Create and connect to a camera instance based on metadata.""" + cam_type = cam_meta.get("type") + cam_id = cam_meta.get("id") + instance = None + + logger.info(f"Preparing {cam_type} ID {cam_id} with default profile") + + try: + if cam_type == "OpenCV": + cv_config = OpenCVCameraConfig( + index_or_path=cam_id, + color_mode=ColorMode.RGB, + ) + instance = OpenCVCamera(cv_config) + elif cam_type == "RealSense": + rs_config = RealSenseCameraConfig( + serial_number_or_name=cam_id, + color_mode=ColorMode.RGB, + ) + instance = RealSenseCamera(rs_config) + else: + logger.warning(f"Unknown camera type: {cam_type} for ID {cam_id}. Skipping.") + return None + + if instance: + logger.info(f"Connecting to {cam_type} camera: {cam_id}...") + instance.connect(warmup=False) + return {"instance": instance, "meta": cam_meta} + except Exception as e: + logger.error(f"Failed to connect or configure {cam_type} camera {cam_id}: {e}") + if instance and instance.is_connected: + instance.disconnect() + return None + + +def process_camera_image( + cam_dict: Dict[str, Any], output_dir: Path, current_time: float +) -> concurrent.futures.Future | None: + """Capture and process an image from a single camera.""" + cam = cam_dict["instance"] + meta = cam_dict["meta"] + cam_type_str = str(meta.get("type", "unknown")) + cam_id_str = str(meta.get("id", "unknown")) + + try: + image_data = cam.read() + + return save_image( + image_data, + cam_id_str, + output_dir, + cam_type_str, + ) + except TimeoutError: + logger.warning( + f"Timeout reading from {cam_type_str} camera {cam_id_str} at time {current_time:.2f}s." + ) + except Exception as e: + logger.error(f"Error reading from {cam_type_str} camera {cam_id_str}: {e}") + return None + + +def cleanup_cameras(cameras_to_use: List[Dict[str, Any]]): + """Disconnect all cameras.""" + logger.info(f"Disconnecting {len(cameras_to_use)} cameras...") + for cam_dict in cameras_to_use: + try: + if cam_dict["instance"] and cam_dict["instance"].is_connected: + cam_dict["instance"].disconnect() + except Exception as e: + logger.error(f"Error disconnecting camera {cam_dict['meta'].get('id')}: {e}") + + +def save_images_from_all_cameras( + output_dir: Path, + record_time_s: float = 2.0, + camera_type: str | None = None, +): + """ + Connects to detected cameras (optionally filtered by type) and saves images from each. + Uses default stream profiles for width, height, and FPS. + + Args: + output_dir: Directory to save images. + record_time_s: Duration in seconds to record images. + camera_type: Optional string to filter cameras ("realsense" or "opencv"). + If None, uses all detected cameras. + """ + output_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Saving images to {output_dir}") + all_camera_metadata = find_and_print_cameras(camera_type_filter=camera_type) + + if not all_camera_metadata: + logger.warning("No cameras detected matching the criteria. Cannot save images.") + return + + cameras_to_use = [] + for cam_meta in all_camera_metadata: + camera_instance = create_camera_instance(cam_meta) + if camera_instance: + cameras_to_use.append(camera_instance) + + if not cameras_to_use: + logger.warning("No cameras could be connected. Aborting image save.") + return + + logger.info(f"Starting image capture for {record_time_s} seconds from {len(cameras_to_use)} cameras.") + start_time = time.perf_counter() + + with concurrent.futures.ThreadPoolExecutor(max_workers=len(cameras_to_use) * 2) as executor: + try: + while time.perf_counter() - start_time < record_time_s: + futures = [] + current_capture_time = time.perf_counter() + + for cam_dict in cameras_to_use: + future = process_camera_image(cam_dict, output_dir, current_capture_time) + if future: + futures.append(future) + + if futures: + concurrent.futures.wait(futures) + + except KeyboardInterrupt: + logger.info("Capture interrupted by user.") + finally: + print("\nFinalizing image saving...") + executor.shutdown(wait=True) + cleanup_cameras(cameras_to_use) + print(f"Image capture finished. Images saved to {output_dir}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Unified camera utility script for listing cameras and capturing images." + ) + + parser.add_argument( + "camera_type", + type=str, + nargs="?", + default=None, + choices=["realsense", "opencv"], + help="Specify camera type to capture from (e.g., 'realsense', 'opencv'). Captures from all if omitted.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default="outputs/captured_images", + help="Directory to save images. Default: outputs/captured_images", + ) + parser.add_argument( + "--record-time-s", + type=float, + default=6.0, + help="Time duration to attempt capturing frames. Default: 6 seconds.", + ) + args = parser.parse_args() + save_images_from_all_cameras(**vars(args)) diff --git a/lerobot/find_port.py b/lerobot/find_port.py new file mode 100644 index 000000000..e69de29bb diff --git a/lerobot/record.py b/lerobot/record.py new file mode 100644 index 000000000..ce6f538d5 --- /dev/null +++ b/lerobot/record.py @@ -0,0 +1,347 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Records a dataset. Actions for the robot can be either generated by teleoperation or by a policy. + +Example: + +```shell +python -m lerobot.record \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \ + --robot.id=black \ + --dataset.repo_id=aliberts/record-test \ + --dataset.num_episodes=2 \ + --dataset.single_task="Grab the cube" \ + # <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \ + # --teleop.type=so100_leader \ + # --teleop.port=/dev/tty.usbmodem58760431551 \ + # --teleop.id=blue \ + # <- Policy optional if you want to record with a policy \ + # --policy.path=${HF_USER}/my_policy \ +``` +""" + +import logging +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from pprint import pformat + +import numpy as np +import rerun as rr + +from lerobot.common.cameras import ( # noqa: F401 + CameraConfig, # noqa: F401 +) +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.common.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.common.datasets.image_writer import safe_stop_image_writer +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import build_dataset_frame, hw_to_dataset_features +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, + so101_follower, +) +from lerobot.common.teleoperators import ( # noqa: F401 + Teleoperator, + TeleoperatorConfig, + koch_leader, + make_teleoperator_from_config, + so100_leader, + so101_leader, +) +from lerobot.common.utils.control_utils import ( + init_keyboard_listener, + is_headless, + predict_action, + sanity_check_dataset_name, + sanity_check_dataset_robot_compatibility, +) +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import ( + get_safe_torch_device, + init_logging, + log_say, +) +from lerobot.common.utils.visualization_utils import _init_rerun +from lerobot.configs import parser +from lerobot.configs.policies import PreTrainedConfig + + +@dataclass +class DatasetRecordConfig: + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str + # A short but accurate description of the task performed during the recording (e.g. "Pick the Lego block and drop it in the box on the right.") + single_task: str + # Root directory where the dataset will be stored (e.g. 'dataset/path'). + root: str | Path | None = None + # Limit the frames per second. + fps: int = 30 + # Number of seconds for data recording for each episode. + episode_time_s: int | float = 60 + # Number of seconds for resetting the environment after each episode. + reset_time_s: int | float = 60 + # Number of episodes to record. + num_episodes: int = 50 + # Encode frames in the dataset into video + video: bool = True + # Upload dataset to Hugging Face hub. + push_to_hub: bool = True + # Upload on private repository on the Hugging Face hub. + private: bool = False + # Add tags to your dataset on the hub. + tags: list[str] | None = None + # Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; + # set to ≥1 to use subprocesses, each using threads to write images. The best number of processes + # and threads depends on your system. We recommend 4 threads per camera with 0 processes. + # If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses. + num_image_writer_processes: int = 0 + # Number of threads writing the frames as png images on disk, per camera. + # Too many threads might cause unstable teleoperation fps due to main thread being blocked. + # Not enough threads might cause low camera fps. + num_image_writer_threads_per_camera: int = 4 + + def __post_init__(self): + if self.single_task is None: + raise ValueError("You need to provide a task as argument in `single_task`.") + + +@dataclass +class RecordConfig: + robot: RobotConfig + dataset: DatasetRecordConfig + # Whether to control the robot with a teleoperator + teleop: TeleoperatorConfig | None = None + # Whether to control the robot with a policy + policy: PreTrainedConfig | None = None + # Display all cameras on screen + display_data: bool = False + # Use vocal synthesis to read events. + play_sounds: bool = True + # Resume recording on an existing dataset. + resume: bool = False + + def __post_init__(self): + # HACK: We parse again the cli args here to get the pretrained path if there was one. + policy_path = parser.get_path_arg("policy") + if policy_path: + cli_overrides = parser.get_cli_overrides("policy") + self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides) + self.policy.pretrained_path = policy_path + + if self.teleop is None and self.policy is None: + raise ValueError("Choose a policy, a teleoperator or both to control the robot") + + @classmethod + def __get_path_fields__(cls) -> list[str]: + """This enables the parser to load config from the policy using `--policy.path=local/dir`""" + return ["policy"] + + +@safe_stop_image_writer +def record_loop( + robot: Robot, + events: dict, + fps: int, + dataset: LeRobotDataset | None = None, + teleop: Teleoperator | None = None, + policy: PreTrainedPolicy | None = None, + control_time_s: int | None = None, + single_task: str | None = None, + display_data: bool = False, +): + if dataset is not None and dataset.fps != fps: + raise ValueError(f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps}).") + + # if policy is given it needs cleaning up + if policy is not None: + policy.reset() + + timestamp = 0 + start_episode_t = time.perf_counter() + while timestamp < control_time_s: + start_loop_t = time.perf_counter() + + if events["exit_early"]: + events["exit_early"] = False + break + + observation = robot.get_observation() + + if policy is not None or dataset is not None: + observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") + + if policy is not None: + action_values = predict_action( + observation_frame, + policy, + get_safe_torch_device(policy.config.device), + policy.config.use_amp, + task=single_task, + robot_type=robot.robot_type, + ) + action = {key: action_values[i].item() for i, key in enumerate(robot.action_features)} + elif policy is None and teleop is not None: + action = teleop.get_action() + else: + logging.info( + "No policy or teleoperator provided, skipping action generation." + "This is likely to happen when resetting the environment without a teleop device." + "The robot won't be at its rest position at the start of the next episode." + ) + continue + + # Action can eventually be clipped using `max_relative_target`, + # so action actually sent is saved in the dataset. + sent_action = robot.send_action(action) + + if dataset is not None: + action_frame = build_dataset_frame(dataset.features, sent_action, prefix="action") + frame = {**observation_frame, **action_frame} + dataset.add_frame(frame, task=single_task) + + if display_data: + for obs, val in observation.items(): + if isinstance(val, float): + rr.log(f"observation.{obs}", rr.Scalar(val)) + elif isinstance(val, np.ndarray): + rr.log(f"observation.{obs}", rr.Image(val), static=True) + for act, val in action.items(): + if isinstance(val, float): + rr.log(f"action.{act}", rr.Scalar(val)) + + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / fps - dt_s) + + timestamp = time.perf_counter() - start_episode_t + + +@parser.wrap() +def record(cfg: RecordConfig) -> LeRobotDataset: + init_logging() + logging.info(pformat(asdict(cfg))) + if cfg.display_data: + _init_rerun(session_name="recording") + + robot = make_robot_from_config(cfg.robot) + teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None + + action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video) + obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video) + dataset_features = {**action_features, **obs_features} + + if cfg.resume: + dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + ) + + if hasattr(robot, "cameras") and len(robot.cameras) > 0: + dataset.start_image_writer( + num_processes=cfg.dataset.num_image_writer_processes, + num_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + ) + sanity_check_dataset_robot_compatibility(dataset, robot, cfg.dataset.fps, dataset_features) + else: + # Create empty dataset or load existing saved episodes + sanity_check_dataset_name(cfg.dataset.repo_id, cfg.policy) + dataset = LeRobotDataset.create( + cfg.dataset.repo_id, + cfg.dataset.fps, + root=cfg.dataset.root, + robot_type=robot.name, + features=dataset_features, + use_videos=cfg.dataset.video, + image_writer_processes=cfg.dataset.num_image_writer_processes, + image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras), + ) + + # Load pretrained policy + policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + + robot.connect() + if teleop is not None: + teleop.connect() + + listener, events = init_keyboard_listener() + + recorded_episodes = 0 + while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]: + log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) + record_loop( + robot=robot, + events=events, + fps=cfg.dataset.fps, + teleop=teleop, + policy=policy, + dataset=dataset, + control_time_s=cfg.dataset.episode_time_s, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + ) + + # Execute a few seconds without recording to give time to manually reset the environment + # Skip reset for the last episode to be recorded + if not events["stop_recording"] and ( + (recorded_episodes < cfg.dataset.num_episodes - 1) or events["rerecord_episode"] + ): + log_say("Reset the environment", cfg.play_sounds) + record_loop( + robot=robot, + events=events, + fps=cfg.dataset.fps, + teleop=teleop, + control_time_s=cfg.dataset.reset_time_s, + single_task=cfg.dataset.single_task, + display_data=cfg.display_data, + ) + + if events["rerecord_episode"]: + log_say("Re-record episode", cfg.play_sounds) + events["rerecord_episode"] = False + events["exit_early"] = False + dataset.clear_episode_buffer() + continue + + dataset.save_episode() + recorded_episodes += 1 + + log_say("Stop recording", cfg.play_sounds, blocking=True) + + robot.disconnect() + if teleop is not None: + teleop.disconnect() + + if not is_headless() and listener is not None: + listener.stop() + + if cfg.dataset.push_to_hub: + dataset.push_to_hub(tags=cfg.dataset.tags, private=cfg.dataset.private) + + log_say("Exiting", cfg.play_sounds) + return dataset + + +if __name__ == "__main__": + record() diff --git a/lerobot/replay.py b/lerobot/replay.py new file mode 100644 index 000000000..36eb0864d --- /dev/null +++ b/lerobot/replay.py @@ -0,0 +1,102 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Replays the actions of an episode from a dataset on a robot. + +Example: + +```shell +python -m lerobot.replay \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --dataset.repo_id=aliberts/record-test \ + --dataset.episode=2 +``` +""" + +import logging +import time +from dataclasses import asdict, dataclass +from pathlib import Path +from pprint import pformat + +import draccus + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, + so101_follower, +) +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import ( + init_logging, + log_say, +) + + +@dataclass +class DatasetReplayConfig: + # Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`). + repo_id: str + # Episode to replay. + episode: int + # Root directory where the dataset will be stored (e.g. 'dataset/path'). + root: str | Path | None = None + # Limit the frames per second. By default, uses the policy fps. + fps: int = 30 + + +@dataclass +class ReplayConfig: + robot: RobotConfig + dataset: DatasetReplayConfig + # Use vocal synthesis to read events. + play_sounds: bool = True + + +@draccus.wrap() +def replay(cfg: ReplayConfig): + init_logging() + logging.info(pformat(asdict(cfg))) + + robot = make_robot_from_config(cfg.robot) + dataset = LeRobotDataset(cfg.dataset.repo_id, root=cfg.dataset.root, episodes=[cfg.dataset.episode]) + actions = dataset.hf_dataset.select_columns("action") + robot.connect() + + log_say("Replaying episode", cfg.play_sounds, blocking=True) + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action_array = actions[idx]["action"] + action = {} + for i, name in enumerate(dataset.features["action"]["names"]): + action[name] = action_array[i] + + robot.send_action(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / dataset.fps - dt_s) + + robot.disconnect() + + +if __name__ == "__main__": + replay() diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py deleted file mode 100644 index b0dc8a97d..000000000 --- a/lerobot/scripts/configure_motor.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script configure a single motor at a time to a given ID and baudrate. - -Example of usage: -```bash -python lerobot/scripts/configure_motor.py \ - --port /dev/tty.usbmodem585A0080521 \ - --brand feetech \ - --model sts3215 \ - --baudrate 1000000 \ - --ID 1 -``` -""" - -import argparse -import time - - -def get_motor_bus_cls(brand: str) -> tuple: - if brand == "feetech": - from lerobot.common.robot_devices.motors.configs import FeetechMotorsBusConfig - from lerobot.common.robot_devices.motors.feetech import ( - MODEL_BAUDRATE_TABLE, - SCS_SERIES_BAUDRATE_TABLE, - FeetechMotorsBus, - ) - - return FeetechMotorsBusConfig, FeetechMotorsBus, MODEL_BAUDRATE_TABLE, SCS_SERIES_BAUDRATE_TABLE - - elif brand == "dynamixel": - from lerobot.common.robot_devices.motors.configs import DynamixelMotorsBusConfig - from lerobot.common.robot_devices.motors.dynamixel import ( - MODEL_BAUDRATE_TABLE, - X_SERIES_BAUDRATE_TABLE, - DynamixelMotorsBus, - ) - - return DynamixelMotorsBusConfig, DynamixelMotorsBus, MODEL_BAUDRATE_TABLE, X_SERIES_BAUDRATE_TABLE - - else: - raise ValueError( - f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors." - ) - - -def configure_motor(port, brand, model, motor_idx_des, baudrate_des): - motor_bus_config_cls, motor_bus_cls, model_baudrate_table, series_baudrate_table = get_motor_bus_cls( - brand - ) - - # Check if the provided model exists in the model_baud_rate_table - if model not in model_baudrate_table: - raise ValueError( - f"Invalid model '{model}' for brand '{brand}'. Supported models: {list(model_baudrate_table.keys())}" - ) - - # Setup motor names, indices, and models - motor_name = "motor" - motor_index_arbitrary = motor_idx_des # Use the motor ID passed via argument - motor_model = model # Use the motor model passed via argument - - config = motor_bus_config_cls(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}) - - # Initialize the MotorBus with the correct port and motor configurations - motor_bus = motor_bus_cls(config=config) - - # Try to connect to the motor bus and handle any connection-specific errors - try: - motor_bus.connect() - print(f"Connected on port {motor_bus.port}") - except OSError as e: - print(f"Error occurred when connecting to the motor bus: {e}") - return - - # Motor bus is connected, proceed with the rest of the operations - try: - print("Scanning all baudrates and motor indices") - all_baudrates = set(series_baudrate_table.values()) - motor_index = -1 # Set the motor index to an out-of-range value. - - for baudrate in all_baudrates: - motor_bus.set_bus_baudrate(baudrate) - present_ids = motor_bus.find_motor_indices(list(range(1, 10))) - if len(present_ids) > 1: - raise ValueError( - "Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor." - ) - - if len(present_ids) == 1: - if motor_index != -1: - raise ValueError( - "Error: More than one motor ID detected. This script is designed to only handle one motor at a time. Please disconnect all but one motor." - ) - motor_index = present_ids[0] - break - - if motor_index == -1: - raise ValueError("No motors detected. Please ensure you have one motor connected.") - - print(f"Motor index found at: {motor_index}") - - if brand == "feetech": - # Allows ID and BAUDRATE to be written in memory - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) - - if baudrate != baudrate_des: - print(f"Setting its baudrate to {baudrate_des}") - baudrate_idx = list(series_baudrate_table.values()).index(baudrate_des) - - # The write can fail, so we allow retries - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx) - time.sleep(0.5) - motor_bus.set_bus_baudrate(baudrate_des) - present_baudrate_idx = motor_bus.read_with_motor_ids( - motor_bus.motor_models, motor_index, "Baud_Rate", num_retry=2 - ) - - if present_baudrate_idx != baudrate_idx: - raise OSError("Failed to write baudrate.") - - print(f"Setting its index to desired index {motor_idx_des}") - if brand == "feetech": - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des) - - present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2) - if present_idx != motor_idx_des: - raise OSError("Failed to write index.") - - if brand == "feetech": - # Set Maximum_Acceleration to 254 to speedup acceleration and deceleration of - # the motors. Note: this configuration is not in the official STS3215 Memory Table - motor_bus.write("Lock", 0) - motor_bus.write("Maximum_Acceleration", 254) - - motor_bus.write("Goal_Position", 2048) - time.sleep(4) - print("Present Position", motor_bus.read("Present_Position")) - - motor_bus.write("Offset", 0) - time.sleep(4) - print("Offset", motor_bus.read("Offset")) - - except Exception as e: - print(f"Error occurred during motor configuration: {e}") - - finally: - motor_bus.disconnect() - print("Disconnected from motor bus.") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)") - parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)") - parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)") - parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)") - parser.add_argument( - "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)" - ) - args = parser.parse_args() - - configure_motor(args.port, args.brand, args.model, args.ID, args.baudrate) diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py deleted file mode 100644 index 3daea98d3..000000000 --- a/lerobot/scripts/control_robot.py +++ /dev/null @@ -1,437 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utilities to control a robot. - -Useful to record a dataset, replay a recorded episode, run the policy on your robot -and record an evaluation dataset, and to recalibrate your robot if needed. - -Examples of usage: - -- Recalibrate your robot: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=calibrate -``` - -- Unlimited teleoperation at highest frequency (~200 Hz is expected), to exit with CTRL+C: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --robot.cameras='{}' \ - --control.type=teleoperate - -# Add the cameras from the robot definition to visualize them: -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=teleoperate -``` - -- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=teleoperate \ - --control.fps=30 -``` - -- Record one episode in order to test replay: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=$USER/koch_test \ - --control.num_episodes=1 \ - --control.push_to_hub=True -``` - -- Visualize dataset: -```bash -python lerobot/scripts/visualize_dataset.py \ - --repo-id $USER/koch_test \ - --episode-index 0 -``` - -- Replay this test episode: -```bash -python lerobot/scripts/control_robot.py replay \ - --robot.type=so100 \ - --control.type=replay \ - --control.fps=30 \ - --control.repo_id=$USER/koch_test \ - --control.episode=0 -``` - -- Record a full dataset in order to train a policy, with 2 seconds of warmup, -30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes: -```bash -python lerobot/scripts/control_robot.py record \ - --robot.type=so100 \ - --control.type=record \ - --control.fps 30 \ - --control.repo_id=$USER/koch_pick_place_lego \ - --control.num_episodes=50 \ - --control.warmup_time_s=2 \ - --control.episode_time_s=30 \ - --control.reset_time_s=10 -``` - -- For remote controlled robots like LeKiwi, run this script on the robot edge device (e.g. RaspBerryPi): -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=lekiwi \ - --control.type=remote_robot -``` - -**NOTE**: You can use your keyboard to control data recording flow. -- Tap right arrow key '->' to early exit while recording an episode and go to resseting the environment. -- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode. -- Tap left arrow key '<-' to early exit and re-record the current episode. -- Tap escape key 'esc' to stop the data recording. -This might require a sudo permission to allow your terminal to monitor keyboard events. - -**NOTE**: You can resume/continue data recording by running the same data recording command and adding `--control.resume=true`. - -- Train on this dataset with the ACT policy: -```bash -python lerobot/scripts/train.py \ - --dataset.repo_id=${HF_USER}/koch_pick_place_lego \ - --policy.type=act \ - --output_dir=outputs/train/act_koch_pick_place_lego \ - --job_name=act_koch_pick_place_lego \ - --device=cuda \ - --wandb.enable=true -``` - -- Run the pretrained policy on the robot: -```bash -python lerobot/scripts/control_robot.py \ - --robot.type=so100 \ - --control.type=record \ - --control.fps=30 \ - --control.single_task="Grasp a lego block and put it in the bin." \ - --control.repo_id=$USER/eval_act_koch_pick_place_lego \ - --control.num_episodes=10 \ - --control.warmup_time_s=2 \ - --control.episode_time_s=30 \ - --control.reset_time_s=10 \ - --control.push_to_hub=true \ - --control.policy.path=outputs/train/act_koch_pick_place_lego/checkpoints/080000/pretrained_model -``` -""" - -import logging -import os -import time -from dataclasses import asdict -from pprint import pformat - -import rerun as rr - -# from safetensors.torch import load_file, save_file -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.policies.factory import make_policy -from lerobot.common.robot_devices.control_configs import ( - CalibrateControlConfig, - ControlConfig, - ControlPipelineConfig, - RecordControlConfig, - RemoteRobotConfig, - ReplayControlConfig, - TeleoperateControlConfig, -) -from lerobot.common.robot_devices.control_utils import ( - control_loop, - init_keyboard_listener, - is_headless, - log_control_info, - record_episode, - reset_environment, - sanity_check_dataset_name, - sanity_check_dataset_robot_compatibility, - stop_recording, - warmup_record, -) -from lerobot.common.robot_devices.robots.utils import Robot, make_robot_from_config -from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect -from lerobot.common.utils.utils import has_method, init_logging, log_say -from lerobot.configs import parser - -######################################################################################## -# Control modes -######################################################################################## - - -@safe_disconnect -def calibrate(robot: Robot, cfg: CalibrateControlConfig): - # TODO(aliberts): move this code in robots' classes - if robot.robot_type.startswith("stretch"): - if not robot.is_connected: - robot.connect() - if not robot.is_homed(): - robot.home() - return - - arms = robot.available_arms if cfg.arms is None else cfg.arms - unknown_arms = [arm_id for arm_id in arms if arm_id not in robot.available_arms] - available_arms_str = " ".join(robot.available_arms) - unknown_arms_str = " ".join(unknown_arms) - - if arms is None or len(arms) == 0: - raise ValueError( - "No arm provided. Use `--arms` as argument with one or more available arms.\n" - f"For instance, to recalibrate all arms add: `--arms {available_arms_str}`" - ) - - if len(unknown_arms) > 0: - raise ValueError( - f"Unknown arms provided ('{unknown_arms_str}'). Available arms are `{available_arms_str}`." - ) - - for arm_id in arms: - arm_calib_path = robot.calibration_dir / f"{arm_id}.json" - if arm_calib_path.exists(): - print(f"Removing '{arm_calib_path}'") - arm_calib_path.unlink() - else: - print(f"Calibration file not found '{arm_calib_path}'") - - if robot.is_connected: - robot.disconnect() - - if robot.robot_type.startswith("lekiwi") and "main_follower" in arms: - print("Calibrating only the lekiwi follower arm 'main_follower'...") - robot.calibrate_follower() - return - - if robot.robot_type.startswith("lekiwi") and "main_leader" in arms: - print("Calibrating only the lekiwi leader arm 'main_leader'...") - robot.calibrate_leader() - return - - # Calling `connect` automatically runs calibration - # when the calibration file is missing - robot.connect() - robot.disconnect() - print("Calibration is done! You can now teleoperate and record datasets!") - - -@safe_disconnect -def teleoperate(robot: Robot, cfg: TeleoperateControlConfig): - control_loop( - robot, - control_time_s=cfg.teleop_time_s, - fps=cfg.fps, - teleoperate=True, - display_data=cfg.display_data, - ) - - -@safe_disconnect -def record( - robot: Robot, - cfg: RecordControlConfig, -) -> LeRobotDataset: - # TODO(rcadene): Add option to record logs - if cfg.resume: - dataset = LeRobotDataset( - cfg.repo_id, - root=cfg.root, - ) - if len(robot.cameras) > 0: - dataset.start_image_writer( - num_processes=cfg.num_image_writer_processes, - num_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), - ) - sanity_check_dataset_robot_compatibility(dataset, robot, cfg.fps, cfg.video) - else: - # Create empty dataset or load existing saved episodes - sanity_check_dataset_name(cfg.repo_id, cfg.policy) - dataset = LeRobotDataset.create( - cfg.repo_id, - cfg.fps, - root=cfg.root, - robot=robot, - use_videos=cfg.video, - image_writer_processes=cfg.num_image_writer_processes, - image_writer_threads=cfg.num_image_writer_threads_per_camera * len(robot.cameras), - ) - - # Load pretrained policy - policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) - - if not robot.is_connected: - robot.connect() - - listener, events = init_keyboard_listener() - - # Execute a few seconds without recording to: - # 1. teleoperate the robot to move it in starting position if no policy provided, - # 2. give times to the robot devices to connect and start synchronizing, - # 3. place the cameras windows on screen - enable_teleoperation = policy is None - log_say("Warmup record", cfg.play_sounds) - warmup_record(robot, events, enable_teleoperation, cfg.warmup_time_s, cfg.display_data, cfg.fps) - - if has_method(robot, "teleop_safety_stop"): - robot.teleop_safety_stop() - - recorded_episodes = 0 - while True: - if recorded_episodes >= cfg.num_episodes: - break - - log_say(f"Recording episode {dataset.num_episodes}", cfg.play_sounds) - record_episode( - robot=robot, - dataset=dataset, - events=events, - episode_time_s=cfg.episode_time_s, - display_data=cfg.display_data, - policy=policy, - fps=cfg.fps, - single_task=cfg.single_task, - ) - - # Execute a few seconds without recording to give time to manually reset the environment - # Current code logic doesn't allow to teleoperate during this time. - # TODO(rcadene): add an option to enable teleoperation during reset - # Skip reset for the last episode to be recorded - if not events["stop_recording"] and ( - (recorded_episodes < cfg.num_episodes - 1) or events["rerecord_episode"] - ): - log_say("Reset the environment", cfg.play_sounds) - reset_environment(robot, events, cfg.reset_time_s, cfg.fps) - - if events["rerecord_episode"]: - log_say("Re-record episode", cfg.play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue - - dataset.save_episode() - recorded_episodes += 1 - - if events["stop_recording"]: - break - - log_say("Stop recording", cfg.play_sounds, blocking=True) - stop_recording(robot, listener, cfg.display_data) - - if cfg.push_to_hub: - dataset.push_to_hub(tags=cfg.tags, private=cfg.private) - - log_say("Exiting", cfg.play_sounds) - return dataset - - -@safe_disconnect -def replay( - robot: Robot, - cfg: ReplayControlConfig, -): - # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset - # TODO(rcadene): Add option to record logs - - dataset = LeRobotDataset(cfg.repo_id, root=cfg.root, episodes=[cfg.episode]) - actions = dataset.hf_dataset.select_columns("action") - - if not robot.is_connected: - robot.connect() - - log_say("Replaying episode", cfg.play_sounds, blocking=True) - for idx in range(dataset.num_frames): - start_episode_t = time.perf_counter() - - action = actions[idx]["action"] - robot.send_action(action) - - dt_s = time.perf_counter() - start_episode_t - busy_wait(1 / cfg.fps - dt_s) - - dt_s = time.perf_counter() - start_episode_t - log_control_info(robot, dt_s, fps=cfg.fps) - - -def _init_rerun(control_config: ControlConfig, session_name: str = "lerobot_control_loop") -> None: - """Initializes the Rerun SDK for visualizing the control loop. - - Args: - control_config: Configuration determining data display and robot type. - session_name: Rerun session name. Defaults to "lerobot_control_loop". - - Raises: - ValueError: If viewer IP is missing for non-remote configurations with display enabled. - """ - if (control_config.display_data and not is_headless()) or ( - control_config.display_data and isinstance(control_config, RemoteRobotConfig) - ): - # Configure Rerun flush batch size default to 8KB if not set - batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") - os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size - - # Initialize Rerun based on configuration - rr.init(session_name) - if isinstance(control_config, RemoteRobotConfig): - viewer_ip = control_config.viewer_ip - viewer_port = control_config.viewer_port - if not viewer_ip or not viewer_port: - raise ValueError( - "Viewer IP & Port are required for remote config. Set via config file/CLI or disable control_config.display_data." - ) - logging.info(f"Connecting to viewer at {viewer_ip}:{viewer_port}") - rr.connect_tcp(f"{viewer_ip}:{viewer_port}") - else: - # Get memory limit for rerun viewer parameters - memory_limit = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "10%") - rr.spawn(memory_limit=memory_limit) - - -@parser.wrap() -def control_robot(cfg: ControlPipelineConfig): - init_logging() - logging.info(pformat(asdict(cfg))) - - robot = make_robot_from_config(cfg.robot) - - # TODO(Steven): Blueprint for fixed window size - - if isinstance(cfg.control, CalibrateControlConfig): - calibrate(robot, cfg.control) - elif isinstance(cfg.control, TeleoperateControlConfig): - _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_teleop") - teleoperate(robot, cfg.control) - elif isinstance(cfg.control, RecordControlConfig): - _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_record") - record(robot, cfg.control) - elif isinstance(cfg.control, ReplayControlConfig): - replay(robot, cfg.control) - elif isinstance(cfg.control, RemoteRobotConfig): - from lerobot.common.robot_devices.robots.lekiwi_remote import run_lekiwi - - _init_rerun(control_config=cfg.control, session_name="lerobot_control_loop_remote") - run_lekiwi(cfg.robot) - - if robot.is_connected: - # Disconnect manually to avoid a "Core dump" during process - # termination due to camera threads not properly exiting. - robot.disconnect() - - -if __name__ == "__main__": - control_robot() diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py deleted file mode 100644 index 5347822c8..000000000 --- a/lerobot/scripts/control_sim_robot.py +++ /dev/null @@ -1,561 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Utilities to control a robot in simulation. - -Useful to record a dataset, replay a recorded episode and record an evaluation dataset. - -Examples of usage: - - -- Unlimited teleoperation at a limited frequency of 30 Hz, to simulate data recording frequency. - You can modify this value depending on how fast your simulation can run: -```bash -python lerobot/scripts/control_robot.py teleoperate \ - --fps 30 \ - --robot-path lerobot/configs/robot/your_robot_config.yaml \ - --sim-config lerobot/configs/env/your_sim_config.yaml -``` - -- Record one episode in order to test replay: -```bash -python lerobot/scripts/control_sim_robot.py record \ - --robot-path lerobot/configs/robot/your_robot_config.yaml \ - --sim-config lerobot/configs/env/your_sim_config.yaml \ - --fps 30 \ - --repo-id $USER/robot_sim_test \ - --num-episodes 1 \ - --run-compute-stats 0 -``` - -Enable the --push-to-hub 1 to push the recorded dataset to the huggingface hub. - -- Visualize dataset: -```bash -python lerobot/scripts/visualize_dataset.py \ - --repo-id $USER/robot_sim_test \ - --episode-index 0 -``` - -- Replay a sequence of test episodes: -```bash -python lerobot/scripts/control_sim_robot.py replay \ - --robot-path lerobot/configs/robot/your_robot_config.yaml \ - --sim-config lerobot/configs/env/your_sim_config.yaml \ - --fps 30 \ - --repo-id $USER/robot_sim_test \ - --episode 0 -``` -Note: The seed is saved, therefore, during replay we can load the same environment state as the one during collection. - -- Record a full dataset in order to train a policy, -30 seconds of recording for each episode, and 10 seconds to reset the environment in between episodes: -```bash -python lerobot/scripts/control_sim_robot.py record \ - --robot-path lerobot/configs/robot/your_robot_config.yaml \ - --sim-config lerobot/configs/env/your_sim_config.yaml \ - --fps 30 \ - --repo-id $USER/robot_sim_test \ - --num-episodes 50 \ - --episode-time-s 30 \ -``` - -**NOTE**: You can use your keyboard to control data recording flow. -- Tap right arrow key '->' to early exit while recording an episode and go to resetting the environment. -- Tap right arrow key '->' to early exit while resetting the environment and got to recording the next episode. -- Tap left arrow key '<-' to early exit and re-record the current episode. -- Tap escape key 'esc' to stop the data recording. -This might require a sudo permission to allow your terminal to monitor keyboard events. - -**NOTE**: You can resume/continue data recording by running the same data recording command twice. -""" - -import argparse -import importlib -import logging -import time -from pathlib import Path - -import cv2 -import gymnasium as gym -import numpy as np -import torch - -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset -from lerobot.common.robot_devices.control_utils import ( - init_keyboard_listener, - init_policy, - is_headless, - log_control_info, - predict_action, - sanity_check_dataset_name, - sanity_check_dataset_robot_compatibility, - stop_recording, -) -from lerobot.common.robot_devices.robots.utils import Robot, make_robot -from lerobot.common.robot_devices.utils import busy_wait -from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say - -raise NotImplementedError("This script is currently deactivated") - -DEFAULT_FEATURES = { - "next.reward": { - "dtype": "float32", - "shape": (1,), - "names": None, - }, - "next.success": { - "dtype": "bool", - "shape": (1,), - "names": None, - }, - "seed": { - "dtype": "int64", - "shape": (1,), - "names": None, - }, - "timestamp": { - "dtype": "float32", - "shape": (1,), - "names": None, - }, -} - - -######################################################################################## -# Utilities -######################################################################################## -def none_or_int(value): - if value == "None": - return None - return int(value) - - -def init_sim_calibration(robot, cfg): - # Constants necessary for transforming the joint pos of the real robot to the sim - # depending on the robot description used in that sim. - start_pos = np.array(robot.leader_arms.main.calibration["start_pos"]) - axis_directions = np.array(cfg.get("axis_directions", [1])) - offsets = np.array(cfg.get("offsets", [0])) * np.pi - - return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets} - - -def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets): - """Counts - starting position -> radians -> align axes -> offset""" - return axis_directions * (real_positions - start_pos) * 2.0 * np.pi / 4096 + offsets - - -######################################################################################## -# Control modes -######################################################################################## - - -def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None): - env = env() - env.reset() - start_teleop_t = time.perf_counter() - while True: - leader_pos = robot.leader_arms.main.read("Present_Position") - action = process_action_fn(leader_pos) - env.step(np.expand_dims(action, 0)) - if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: - print("Teleoperation processes finished.") - break - - -def record( - env, - robot: Robot, - process_action_from_leader, - root: Path, - repo_id: str, - task: str, - fps: int | None = None, - tags: list[str] | None = None, - pretrained_policy_name_or_path: str = None, - policy_overrides: bool | None = None, - episode_time_s: int = 30, - num_episodes: int = 50, - video: bool = True, - push_to_hub: bool = True, - num_image_writer_processes: int = 0, - num_image_writer_threads_per_camera: int = 4, - display_cameras: bool = False, - play_sounds: bool = True, - resume: bool = False, - local_files_only: bool = False, - run_compute_stats: bool = True, -) -> LeRobotDataset: - # Load pretrained policy - policy = None - if pretrained_policy_name_or_path is not None: - policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) - - if fps is None: - fps = policy_fps - logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") - - if policy is None and process_action_from_leader is None: - raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") - - # initialize listener before sim env - listener, events = init_keyboard_listener() - - # create sim env - env = env() - - # Create empty dataset or load existing saved episodes - num_cameras = sum([1 if "image" in key else 0 for key in env.observation_space]) - - # get image keys - image_keys = [key for key in env.observation_space if "image" in key] - state_keys_dict = env_cfg.state_keys - - if resume: - dataset = LeRobotDataset( - repo_id, - root=root, - local_files_only=local_files_only, - ) - dataset.start_image_writer( - num_processes=num_image_writer_processes, - num_threads=num_image_writer_threads_per_camera * num_cameras, - ) - sanity_check_dataset_robot_compatibility(dataset, robot, fps, video) - else: - features = DEFAULT_FEATURES - # add image keys to features - for key in image_keys: - shape = env.observation_space[key].shape - if not key.startswith("observation.image."): - key = "observation.image." + key - features[key] = {"dtype": "video", "names": ["channels", "height", "width"], "shape": shape} - - for key, obs_key in state_keys_dict.items(): - features[key] = { - "dtype": "float32", - "names": None, - "shape": env.observation_space[obs_key].shape, - } - - features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} - - # Create empty dataset or load existing saved episodes - sanity_check_dataset_name(repo_id, policy) - dataset = LeRobotDataset.create( - repo_id, - fps, - root=root, - features=features, - use_videos=video, - image_writer_processes=num_image_writer_processes, - image_writer_threads=num_image_writer_threads_per_camera * num_cameras, - ) - - recorded_episodes = 0 - while True: - log_say(f"Recording episode {dataset.num_episodes}", play_sounds) - - if events is None: - events = {"exit_early": False} - - if episode_time_s is None: - episode_time_s = float("inf") - - timestamp = 0 - start_episode_t = time.perf_counter() - - seed = np.random.randint(0, 1e5) - observation, info = env.reset(seed=seed) - - while timestamp < episode_time_s: - start_loop_t = time.perf_counter() - - if policy is not None: - action = predict_action(observation, policy, device, use_amp) - else: - leader_pos = robot.leader_arms.main.read("Present_Position") - action = process_action_from_leader(leader_pos) - - observation, reward, terminated, _, info = env.step(action) - - success = info.get("is_success", False) - env_timestamp = info.get("timestamp", dataset.episode_buffer["size"] / fps) - - frame = { - "action": torch.from_numpy(action), - "next.reward": reward, - "next.success": success, - "seed": seed, - "timestamp": env_timestamp, - } - - for key in image_keys: - if not key.startswith("observation.image"): - frame["observation.image." + key] = observation[key] - else: - frame[key] = observation[key] - - for key, obs_key in state_keys_dict.items(): - frame[key] = torch.from_numpy(observation[obs_key]) - - dataset.add_frame(frame) - - if display_cameras and not is_headless(): - for key in image_keys: - cv2.imshow(key, cv2.cvtColor(observation[key], cv2.COLOR_RGB2BGR)) - cv2.waitKey(1) - - if fps is not None: - dt_s = time.perf_counter() - start_loop_t - busy_wait(1 / fps - dt_s) - - dt_s = time.perf_counter() - start_loop_t - log_control_info(robot, dt_s, fps=fps) - - timestamp = time.perf_counter() - start_episode_t - if events["exit_early"] or terminated: - events["exit_early"] = False - break - - if events["rerecord_episode"]: - log_say("Re-record episode", play_sounds) - events["rerecord_episode"] = False - events["exit_early"] = False - dataset.clear_episode_buffer() - continue - - dataset.save_episode(task=task) - recorded_episodes += 1 - - if events["stop_recording"] or recorded_episodes >= num_episodes: - break - else: - logging.info("Waiting for a few seconds before starting next episode recording...") - busy_wait(3) - - log_say("Stop recording", play_sounds, blocking=True) - stop_recording(robot, listener, display_cameras) - - if run_compute_stats: - logging.info("Computing dataset statistics") - dataset.consolidate(run_compute_stats) - - if push_to_hub: - dataset.push_to_hub(tags=tags) - - log_say("Exiting", play_sounds) - return dataset - - -def replay( - env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True -): - env = env() - - local_dir = Path(root) / repo_id - if not local_dir.exists(): - raise ValueError(local_dir) - - dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only) - items = dataset.hf_dataset.select_columns("action") - seeds = dataset.hf_dataset.select_columns("seed")["seed"] - - from_idx = dataset.episode_data_index["from"][episode].item() - to_idx = dataset.episode_data_index["to"][episode].item() - env.reset(seed=seeds[from_idx].item()) - logging.info("Replaying episode") - log_say("Replaying episode", play_sounds=True) - for idx in range(from_idx, to_idx): - start_episode_t = time.perf_counter() - action = items[idx]["action"] - env.step(action.unsqueeze(0).numpy()) - dt_s = time.perf_counter() - start_episode_t - busy_wait(1 / fps - dt_s) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(dest="mode", required=True) - - # Set common options for all the subparsers - base_parser = argparse.ArgumentParser(add_help=False) - base_parser.add_argument( - "--robot-path", - type=str, - default="lerobot/configs/robot/koch.yaml", - help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.", - ) - - base_parser.add_argument( - "--sim-config", - help="Path to a yaml config you want to use for initializing a sim environment based on gym ", - ) - - parser_record = subparsers.add_parser("teleoperate", parents=[base_parser]) - - parser_record = subparsers.add_parser("record", parents=[base_parser]) - parser_record.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" - ) - parser_record.add_argument( - "--root", - type=Path, - default=None, - help="Root directory where the dataset will be stored locally at '{root}/{repo_id}' (e.g. 'data/hf_username/dataset_name').", - ) - parser_record.add_argument( - "--repo-id", - type=str, - default="lerobot/test", - help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", - ) - parser_record.add_argument( - "--episode-time-s", - type=int, - default=60, - help="Number of seconds for data recording for each episode.", - ) - parser_record.add_argument( - "--task", - type=str, - required=True, - help="A description of the task preformed during recording that can be used as a language instruction.", - ) - parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") - parser_record.add_argument( - "--run-compute-stats", - type=int, - default=1, - help="By default, run the computation of the data statistics at the end of data collection. Compute intensive and not required to just replay an episode.", - ) - parser_record.add_argument( - "--push-to-hub", - type=int, - default=1, - help="Upload dataset to Hugging Face hub.", - ) - parser_record.add_argument( - "--tags", - type=str, - nargs="*", - help="Add tags to your dataset on the hub.", - ) - parser_record.add_argument( - "--num-image-writer-processes", - type=int, - default=0, - help=( - "Number of subprocesses handling the saving of frames as PNG. Set to 0 to use threads only; " - "set to ≥1 to use subprocesses, each using threads to write images. The best number of processes " - "and threads depends on your system. We recommend 4 threads per camera with 0 processes. " - "If fps is unstable, adjust the thread count. If still unstable, try using 1 or more subprocesses." - ), - ) - parser_record.add_argument( - "--num-image-writer-threads-per-camera", - type=int, - default=4, - help=( - "Number of threads writing the frames as png images on disk, per camera. " - "Too much threads might cause unstable teleoperation fps due to main thread being blocked. " - "Not enough threads might cause low camera fps." - ), - ) - parser_record.add_argument( - "--display-cameras", - type=int, - default=0, - help="Visualize image observations with opencv.", - ) - parser_record.add_argument( - "--resume", - type=int, - default=0, - help="Resume recording on an existing dataset.", - ) - parser_replay = subparsers.add_parser("replay", parents=[base_parser]) - parser_replay.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" - ) - parser_replay.add_argument( - "--root", - type=Path, - default=None, - help="Root directory where the dataset will be stored locally (e.g. 'data/hf_username/dataset_name'). By default, stored in cache folder.", - ) - parser_replay.add_argument( - "--repo-id", - type=str, - default="lerobot/test", - help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", - ) - parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.") - - args = parser.parse_args() - - init_logging() - - control_mode = args.mode - robot_path = args.robot_path - env_config_path = args.sim_config - kwargs = vars(args) - del kwargs["mode"] - del kwargs["robot_path"] - del kwargs["sim_config"] - - # make gym env - env_cfg = init_hydra_config(env_config_path) - importlib.import_module(f"gym_{env_cfg.env.type}") - - def env_constructor(): - return gym.make(env_cfg.env.handle, disable_env_checker=True, **env_cfg.env.gym) - - robot = None - process_leader_actions_fn = None - - if control_mode in ["teleoperate", "record"]: - # make robot - robot_overrides = ["~cameras", "~follower_arms"] - # TODO(rcadene): remove - robot_cfg = init_hydra_config(robot_path, robot_overrides) - robot = make_robot(robot_cfg) - robot.connect() - - calib_kwgs = init_sim_calibration(robot, env_cfg.calibration) - - def process_leader_actions_fn(action): - return real_positions_to_sim(action, **calib_kwgs) - - robot.leader_arms.main.calibration = None - - if control_mode == "teleoperate": - teleoperate(env_constructor, robot, process_leader_actions_fn) - - elif control_mode == "record": - record(env_constructor, robot, process_leader_actions_fn, **kwargs) - - elif control_mode == "replay": - replay(env_constructor, **kwargs) - - else: - raise ValueError( - f"Invalid control mode: '{control_mode}', only valid modes are teleoperate, record and replay." - ) - - if robot and robot.is_connected: - # Disconnect manually to avoid a "Core dump" during process - # termination due to camera threads not properly exiting. - robot.disconnect() diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 9790f8b31..58275f666 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -94,8 +94,8 @@ def rollout( data will probably need to be discarded (for environments that aren't the first one to be done). The return dictionary contains: - (optional) "observation": A a dictionary of (batch, sequence + 1, *) tensors mapped to observation - keys. NOTE the that this has an extra sequence element relative to the other keys in the + (optional) "observation": A dictionary of (batch, sequence + 1, *) tensors mapped to observation + keys. NOTE that this has an extra sequence element relative to the other keys in the dictionary. This is because an extra observation is included for after the environment is terminated or truncated. "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not diff --git a/lerobot/scripts/find_joint_limits.py b/lerobot/scripts/find_joint_limits.py new file mode 100644 index 000000000..95676dd35 --- /dev/null +++ b/lerobot/scripts/find_joint_limits.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple script to control a robot from teleoperation. + +Example: + +```shell +python -m lerobot.scripts.server.find_joint_limits \ + --robot.type=so100_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.id=black \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue +``` +""" + +import time +from dataclasses import dataclass + +import draccus +import numpy as np + +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( # noqa: F401 + TeleoperatorConfig, + gamepad, + koch_leader, + make_teleoperator_from_config, + so100_leader, +) + + +@dataclass +class FindJointLimitsConfig: + teleop: TeleoperatorConfig + robot: RobotConfig + # Limit the maximum frames per second. By default, no limit. + teleop_time_s: float = 30 + # Display all cameras on screen + display_data: bool = False + + +@draccus.wrap() +def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig): + teleop = make_teleoperator_from_config(cfg.teleop) + robot = make_robot_from_config(cfg.robot) + + teleop.connect() + robot.connect() + + start_episode_t = time.perf_counter() + robot_type = getattr(robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + kinematics = RobotKinematics(robot_type=robot_type) + + # Initialize min/max values + observation = robot.get_observation() + joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors]) + ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3] + + max_pos = joint_positions.copy() + min_pos = joint_positions.copy() + max_ee = ee_pos.copy() + min_ee = ee_pos.copy() + + while True: + action = teleop.get_action() + robot.send_action(action) + + observation = robot.get_observation() + joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors]) + ee_pos = kinematics.forward_kinematics(joint_positions, frame="gripper_tip")[:3, 3] + + # Skip initial warmup period + if (time.perf_counter() - start_episode_t) < 5: + continue + + # Update min/max values + max_ee = np.maximum(max_ee, ee_pos) + min_ee = np.minimum(min_ee, ee_pos) + max_pos = np.maximum(max_pos, joint_positions) + min_pos = np.minimum(min_pos, joint_positions) + + if time.perf_counter() - start_episode_t > cfg.teleop_time_s: + print(f"Max ee position {np.round(max_ee, 4).tolist()}") + print(f"Min ee position {np.round(min_ee, 4).tolist()}") + print(f"Max joint pos position {np.round(max_pos, 4).tolist()}") + print(f"Min joint pos position {np.round(min_pos, 4).tolist()}") + break + + +if __name__ == "__main__": + find_joint_and_ee_bounds() diff --git a/lerobot/scripts/find_motors_bus_port.py b/lerobot/scripts/find_motors_bus_port.py deleted file mode 100644 index 68f2315d7..000000000 --- a/lerobot/scripts/find_motors_bus_port.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -import time -from pathlib import Path - -from serial.tools import list_ports # Part of pyserial library - - -def find_available_ports(): - if os.name == "nt": # Windows - # List COM ports using pyserial - ports = [port.device for port in list_ports.comports()] - else: # Linux/macOS - # List /dev/tty* ports for Unix-based systems - ports = [str(path) for path in Path("/dev").glob("tty*")] - return ports - - -def find_port(): - print("Finding all available ports for the MotorsBus.") - ports_before = find_available_ports() - print("Ports before disconnecting:", ports_before) - - print("Remove the USB cable from your MotorsBus and press Enter when done.") - input() # Wait for user to disconnect the device - - time.sleep(0.5) # Allow some time for port to be released - ports_after = find_available_ports() - ports_diff = list(set(ports_before) - set(ports_after)) - - if len(ports_diff) == 1: - port = ports_diff[0] - print(f"The port of this MotorsBus is '{port}'") - print("Reconnect the USB cable.") - elif len(ports_diff) == 0: - raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).") - else: - raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") - - -if __name__ == "__main__": - # Helper to find the USB port associated with your MotorsBus. - find_port() diff --git a/lerobot/scripts/push_pretrained.py b/lerobot/scripts/push_pretrained.py deleted file mode 100644 index e3c683f96..000000000 --- a/lerobot/scripts/push_pretrained.py +++ /dev/null @@ -1,71 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Once you have trained a policy with our training script (lerobot/scripts/train.py), use this script to push it -to the hub. - -Example: - -```bash -python lerobot/scripts/push_pretrained.py \ - --pretrained_path=outputs/train/act_aloha_sim_transfer_cube_human/checkpoints/last/pretrained_model \ - --repo_id=lerobot/act_aloha_sim_transfer_cube_human -``` -""" - -from dataclasses import dataclass -from pathlib import Path - -import draccus -from huggingface_hub import HfApi - - -@dataclass -class PushPreTrainedConfig: - pretrained_path: Path - repo_id: str - branch: str | None = None - private: bool = False - exist_ok: bool = False - - -@draccus.wrap() -def main(cfg: PushPreTrainedConfig): - hub_api = HfApi() - hub_api.create_repo( - repo_id=cfg.repo_id, - private=cfg.private, - repo_type="model", - exist_ok=cfg.exist_ok, - ) - if cfg.branch: - hub_api.create_branch( - repo_id=cfg.repo_id, - branch=cfg.branch, - repo_type="model", - exist_ok=cfg.exist_ok, - ) - - hub_api.upload_folder( - repo_id=cfg.repo_id, - folder_path=cfg.pretrained_path, - repo_type="model", - revision=cfg.branch, - ) - - -if __name__ == "__main__": - main() diff --git a/lerobot/scripts/rl/actor.py b/lerobot/scripts/rl/actor.py new file mode 100644 index 000000000..da24d0dc5 --- /dev/null +++ b/lerobot/scripts/rl/actor.py @@ -0,0 +1,709 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Actor server runner for distributed HILSerl robot policy training. + +This script implements the actor component of the distributed HILSerl architecture. +It executes the policy in the robot environment, collects experience, +and sends transitions to the learner server for policy updates. + +Examples of usage: + +- Start an actor server for real robot training with human-in-the-loop intervention: +```bash +python lerobot/scripts/rl/actor.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner +server is started before launching the actor. + +**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the +gamepad to take control of the robot during training. Initially intervene frequently, then gradually +reduce interventions as the policy improves. + +**WORKFLOW**: +1. Determine robot workspace bounds using `find_joint_limits.py` +2. Record demonstrations with `gym_manipulator.py` in record mode +3. Process the dataset and determine camera crops with `crop_dataset_roi.py` +4. Start the learner server with the training configuration +5. Start this actor server with the same configuration +6. Use human interventions to guide policy learning + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" + +import logging +import os +import time +from functools import lru_cache +from queue import Empty + +import grpc +import torch +from torch import nn +from torch.multiprocessing import Event, Queue + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robots import so100_follower # noqa: F401 +from lerobot.common.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_state_dict, + python_object_to_bytes, + receive_bytes_in_chunks, + send_bytes_in_chunks, + transitions_to_bytes, +) +from lerobot.common.utils.process import ProcessSignalHandler +from lerobot.common.utils.queue import get_last_item_from_queue +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.transition import ( + Transition, + move_state_dict_to_device, + move_transition_to_device, +) +from lerobot.common.utils.utils import ( + TimerManager, + get_safe_torch_device, + init_logging, +) +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl import learner_service +from lerobot.scripts.rl.gym_manipulator import make_robot_env + +ACTOR_SHUTDOWN_TIMEOUT = 30 + + +################################################# +# Main entry point # +################################################# + + +@parser.wrap() +def actor_cli(cfg: TrainRLServerPipelineConfig): + cfg.validate() + display_pid = False + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + display_pid = True + + # Create logs directory to ensure it exists + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=display_pid) + logging.info(f"Actor logging initialized, writing to {log_file}") + + is_threaded = use_threads(cfg) + shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event + + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + logging.info("[ACTOR] Establishing connection with Learner") + if not establish_learner_connection(learner_client, shutdown_event): + logging.error("[ACTOR] Failed to establish connection with Learner") + return + + if not use_threads(cfg): + # If we use multithreading, we can reuse the channel + grpc_channel.close() + grpc_channel = None + + logging.info("[ACTOR] Connection with Learner established") + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + + concurrency_entity = None + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from multiprocessing import Process + + concurrency_entity = Process + + receive_policy_process = concurrency_entity( + target=receive_policy, + args=(cfg, parameters_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process = concurrency_entity( + target=send_transitions, + args=(cfg, transitions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + interactions_process = concurrency_entity( + target=send_interactions, + args=(cfg, interactions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process.start() + interactions_process.start() + receive_policy_process.start() + + act_with_policy( + cfg=cfg, + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + transitions_queue=transitions_queue, + interactions_queue=interactions_queue, + ) + logging.info("[ACTOR] Policy process joined") + + logging.info("[ACTOR] Closing queues") + transitions_queue.close() + interactions_queue.close() + parameters_queue.close() + + transitions_process.join() + logging.info("[ACTOR] Transitions process joined") + interactions_process.join() + logging.info("[ACTOR] Interactions process joined") + receive_policy_process.join() + logging.info("[ACTOR] Receive policy process joined") + + logging.info("[ACTOR] join queues") + transitions_queue.cancel_join_thread() + interactions_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[ACTOR] queues closed") + + +################################################# +# Core algorithm functions # +################################################# + + +def act_with_policy( + cfg: TrainRLServerPipelineConfig, + shutdown_event: any, # Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, +): + """ + Executes policy interaction within the environment. + + This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner. + Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network. + + Args: + cfg: Configuration settings for the interaction process. + shutdown_event: Event to check if the process should shutdown. + parameters_queue: Queue to receive updated network parameters from the learner. + transitions_queue: Queue to send transitions to the learner. + interactions_queue: Queue to send interactions to the learner. + """ + # Initialize logging for multiprocessing + if not use_threads(cfg): + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log") + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor policy process logging initialized") + + logging.info("make_env online") + + online_env = make_robot_env(cfg=cfg.env) + + set_seed(cfg.seed) + device = get_safe_torch_device(cfg.policy.device, log=True) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info("make_policy") + + ### Instantiate the policy in both the actor and learner processes + ### To avoid sending a SACPolicy object through the port, we create a policy instance + ### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters + policy: SACPolicy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + policy = policy.eval() + assert isinstance(policy, nn.Module) + + obs, info = online_env.reset() + + # NOTE: For the moment we will solely handle the case of a single environment + sum_reward_episode = 0 + list_transition_to_send_to_learner = [] + episode_intervention = False + # Add counters for intervention rate calculation + episode_intervention_steps = 0 + episode_total_steps = 0 + + policy_timer = TimerManager("Policy inference", log=False) + + for interaction_step in range(cfg.policy.online_steps): + start_time = time.perf_counter() + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down act_with_policy") + return + + if interaction_step >= cfg.policy.online_step_before_learning: + # Time policy inference and check if it meets FPS requirement + with policy_timer: + action = policy.select_action(batch=obs) + policy_fps = policy_timer.fps_last + + log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step) + + else: + action = online_env.action_space.sample() + + next_obs, reward, done, truncated, info = online_env.step(action) + + sum_reward_episode += float(reward) + # Increment total steps counter for intervention rate + episode_total_steps += 1 + + # NOTE: We override the action if the intervention is True, because the action applied is the intervention action + if "is_intervention" in info and info["is_intervention"]: + # NOTE: The action space for demonstration before hand is with the full action space + # but sometimes for example we want to deactivate the gripper + action = info["action_intervention"] + episode_intervention = True + # Increment intervention steps counter + episode_intervention_steps += 1 + + list_transition_to_send_to_learner.append( + Transition( + state=obs, + action=action, + reward=reward, + next_state=next_obs, + done=done, + truncated=truncated, # TODO: (azouitine) Handle truncation properly + complementary_info=info, + ) + ) + # assign obs to the next obs and continue the rollout + obs = next_obs + + if done or truncated: + logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}") + + update_policy_parameters(policy=policy.actor, parameters_queue=parameters_queue, device=device) + + if len(list_transition_to_send_to_learner) > 0: + push_transitions_to_transport_queue( + transitions=list_transition_to_send_to_learner, + transitions_queue=transitions_queue, + ) + list_transition_to_send_to_learner = [] + + stats = get_frequency_stats(policy_timer) + policy_timer.reset() + + # Calculate intervention rate + intervention_rate = 0.0 + if episode_total_steps > 0: + intervention_rate = episode_intervention_steps / episode_total_steps + + # Send episodic reward to the learner + interactions_queue.put( + python_object_to_bytes( + { + "Episodic reward": sum_reward_episode, + "Interaction step": interaction_step, + "Episode intervention": int(episode_intervention), + "Intervention rate": intervention_rate, + **stats, + } + ) + ) + + # Reset intervention counters + sum_reward_episode = 0.0 + episode_intervention = False + episode_intervention_steps = 0 + episode_total_steps = 0 + obs, info = online_env.reset() + + if cfg.env.fps is not None: + dt_time = time.perf_counter() - start_time + busy_wait(1 / cfg.env.fps - dt_time) + + +################################################# +# Communication Functions - Group all gRPC/messaging functions # +################################################# + + +def establish_learner_connection( + stub: services_pb2_grpc.LearnerServiceStub, + shutdown_event: Event, # type: ignore + attempts: int = 30, +): + """Establish a connection with the learner. + + Args: + stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection. + shutdown_event (Event): The event to check if the connection should be established. + attempts (int): The number of attempts to establish the connection. + Returns: + bool: True if the connection is established, False otherwise. + """ + for _ in range(attempts): + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down establish_learner_connection") + return False + + # Force a connection attempt and check state + try: + logging.info("[ACTOR] Send ready message to Learner") + if stub.Ready(services_pb2.Empty()) == services_pb2.Empty(): + return True + except grpc.RpcError as e: + logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}") + time.sleep(2) + return False + + +@lru_cache(maxsize=1) +def learner_service_client( + host: str = "127.0.0.1", + port: int = 50051, +) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]: + import json + + """ + Returns a client for the learner service. + + GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. + So we need to create only one client and reuse it. + """ + + service_config = { + "methodConfig": [ + { + "name": [{}], # Applies to ALL methods in ALL services + "retryPolicy": { + "maxAttempts": 5, # Max retries (total attempts = 5) + "initialBackoff": "0.1s", # First retry after 0.1s + "maxBackoff": "2s", # Max wait time between retries + "backoffMultiplier": 2, # Exponential backoff factor + "retryableStatusCodes": [ + "UNAVAILABLE", + "DEADLINE_EXCEEDED", + ], # Retries on network failures + }, + } + ] + } + + service_config_json = json.dumps(service_config) + + channel = grpc.insecure_channel( + f"{host}:{port}", + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.enable_retries", 1), + ("grpc.service_config", service_config_json), + ], + ) + stub = services_pb2_grpc.LearnerServiceStub(channel) + logging.info("[ACTOR] Learner service client created") + return stub, channel + + +def receive_policy( + cfg: TrainRLServerPipelineConfig, + parameters_queue: Queue, + shutdown_event: Event, # type: ignore + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +): + """Receive parameters from the learner. + + Args: + cfg (TrainRLServerPipelineConfig): The configuration for the actor. + parameters_queue (Queue): The queue to receive the parameters. + shutdown_event (Event): The event to check if the process should shutdown. + """ + logging.info("[ACTOR] Start receiving parameters from the Learner") + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor receive policy process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + _ = ProcessSignalHandler(use_threads=False, display_pid=True) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + iterator = learner_client.StreamParameters(services_pb2.Empty()) + receive_bytes_in_chunks( + iterator, + parameters_queue, + shutdown_event, + log_prefix="[ACTOR] parameters", + ) + + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Received policy loop stopped") + + +def send_transitions( + cfg: TrainRLServerPipelineConfig, + transitions_queue: Queue, + shutdown_event: any, # Event, + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> services_pb2.Empty: + """ + Sends transitions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - Transition Data: + - A batch of transitions (observation, action, reward, next observation) is collected. + - Transitions are moved to the CPU and serialized using PyTorch. + - The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner. + """ + + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor transitions process logging initialized") + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + learner_client.SendTransitions( + transitions_stream( + shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout + ) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming transitions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Transitions process stopped") + + +def send_interactions( + cfg: TrainRLServerPipelineConfig, + interactions_queue: Queue, + shutdown_event: Event, # type: ignore + learner_client: services_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> services_pb2.Empty: + """ + Sends interactions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - Interaction Messages: + - Contains useful statistics about episodic rewards and policy timings. + - The message is serialized using `pickle` and sent to the learner. + """ + + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Actor interactions process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + _ = ProcessSignalHandler(use_threads=False, display_pid=True) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.policy.actor_learner_config.learner_host, + port=cfg.policy.actor_learner_config.learner_port, + ) + + try: + learner_client.SendInteractions( + interactions_stream( + shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout + ) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming interactions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Interactions process stopped") + + +def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore + while not shutdown_event.is_set(): + try: + message = transitions_queue.get(block=True, timeout=timeout) + except Empty: + logging.debug("[ACTOR] Transition queue is empty") + continue + + yield from send_bytes_in_chunks( + message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions" + ) + + return services_pb2.Empty() + + +def interactions_stream( + shutdown_event: Event, + interactions_queue: Queue, + timeout: float, # type: ignore +) -> services_pb2.Empty: + while not shutdown_event.is_set(): + try: + message = interactions_queue.get(block=True, timeout=timeout) + except Empty: + logging.debug("[ACTOR] Interaction queue is empty") + continue + + yield from send_bytes_in_chunks( + message, + services_pb2.InteractionMessage, + log_prefix="[ACTOR] Send interactions", + ) + + return services_pb2.Empty() + + +################################################# +# Policy functions # +################################################# + + +def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): + bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False) + if bytes_state_dict is not None: + logging.info("[ACTOR] Load new parameters from Learner.") + state_dict = bytes_to_state_dict(bytes_state_dict) + state_dict = move_state_dict_to_device(state_dict, device=device) + policy.load_state_dict(state_dict) + + +################################################# +# Utilities functions # +################################################# + + +def push_transitions_to_transport_queue(transitions: list, transitions_queue): + """Send transitions to learner in smaller chunks to avoid network issues. + + Args: + transitions: List of transitions to send + message_queue: Queue to send messages to learner + chunk_size: Size of each chunk to send + """ + transition_to_send_to_learner = [] + for transition in transitions: + tr = move_transition_to_device(transition=transition, device="cpu") + for key, value in tr["state"].items(): + if torch.isnan(value).any(): + logging.warning(f"Found NaN values in transition {key}") + + transition_to_send_to_learner.append(tr) + + transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner)) + + +def get_frequency_stats(timer: TimerManager) -> dict[str, float]: + """Get the frequency statistics of the policy. + + Args: + timer (TimerManager): The timer with collected metrics. + + Returns: + dict[str, float]: The frequency statistics of the policy. + """ + stats = {} + if timer.count > 1: + avg_fps = timer.fps_avg + p90_fps = timer.fps_percentile(90) + logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}") + logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}") + stats = { + "Policy frequency [Hz]": avg_fps, + "Policy frequency 90th-p [Hz]": p90_fps, + } + return stats + + +def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int): + if policy_fps < cfg.env.fps: + logging.warning( + f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}" + ) + + +def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: + return cfg.policy.concurrency.actor == "threads" + + +if __name__ == "__main__": + actor_cli() diff --git a/lerobot/scripts/rl/crop_dataset_roi.py b/lerobot/scripts/rl/crop_dataset_roi.py new file mode 100644 index 000000000..5b7038de3 --- /dev/null +++ b/lerobot/scripts/rl/crop_dataset_roi.py @@ -0,0 +1,314 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +from copy import deepcopy +from pathlib import Path +from typing import Dict, Tuple + +import cv2 + +# import torch.nn.functional as F # noqa: N812 +import torchvision.transforms.functional as F # type: ignore # noqa: N812 +from tqdm import tqdm # type: ignore + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +def select_rect_roi(img): + """ + Allows the user to draw a rectangular ROI on the image. + + The user must click and drag to draw the rectangle. + - While dragging, the rectangle is dynamically drawn. + - On mouse button release, the rectangle is fixed. + - Press 'c' to confirm the selection. + - Press 'r' to reset the selection. + - Press ESC to cancel. + + Returns: + A tuple (top, left, height, width) representing the rectangular ROI, + or None if no valid ROI is selected. + """ + # Create a working copy of the image + clone = img.copy() + working_img = clone.copy() + + roi = None # Will store the final ROI as (top, left, height, width) + drawing = False + index_x, index_y = -1, -1 # Initial click coordinates + + def mouse_callback(event, x, y, flags, param): + nonlocal index_x, index_y, drawing, roi, working_img + + if event == cv2.EVENT_LBUTTONDOWN: + # Start drawing: record starting coordinates + drawing = True + index_x, index_y = x, y + + elif event == cv2.EVENT_MOUSEMOVE: + if drawing: + # Compute the top-left and bottom-right corners regardless of drag direction + top = min(index_y, y) + left = min(index_x, x) + bottom = max(index_y, y) + right = max(index_x, x) + # Show a temporary image with the current rectangle drawn + temp = working_img.copy() + cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", temp) + + elif event == cv2.EVENT_LBUTTONUP: + # Finish drawing + drawing = False + top = min(index_y, y) + left = min(index_x, x) + bottom = max(index_y, y) + right = max(index_x, x) + height = bottom - top + width = right - left + roi = (top, left, height, width) # (top, left, height, width) + # Draw the final rectangle on the working image and display it + working_img = clone.copy() + cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", working_img) + + # Create the window and set the callback + cv2.namedWindow("Select ROI") + cv2.setMouseCallback("Select ROI", mouse_callback) + cv2.imshow("Select ROI", working_img) + + print("Instructions for ROI selection:") + print(" - Click and drag to draw a rectangular ROI.") + print(" - Press 'c' to confirm the selection.") + print(" - Press 'r' to reset and draw again.") + print(" - Press ESC to cancel the selection.") + + # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC + while True: + key = cv2.waitKey(1) & 0xFF + # Confirm ROI if one has been drawn + if key == ord("c") and roi is not None: + break + # Reset: clear the ROI and restore the original image + elif key == ord("r"): + working_img = clone.copy() + roi = None + cv2.imshow("Select ROI", working_img) + # Cancel selection for this image + elif key == 27: # ESC key + roi = None + break + + cv2.destroyWindow("Select ROI") + return roi + + +def select_square_roi_for_images(images: dict) -> dict: + """ + For each image in the provided dictionary, open a window to allow the user + to select a rectangular ROI. Returns a dictionary mapping each key to a tuple + (top, left, height, width) representing the ROI. + + Parameters: + images (dict): Dictionary where keys are identifiers and values are OpenCV images. + + Returns: + dict: Mapping of image keys to the selected rectangular ROI. + """ + selected_rois = {} + + for key, img in images.items(): + if img is None: + print(f"Image for key '{key}' is None, skipping.") + continue + + print(f"\nSelect rectangular ROI for image with key: '{key}'") + roi = select_rect_roi(img) + + if roi is None: + print(f"No valid ROI selected for '{key}'.") + else: + selected_rois[key] = roi + print(f"ROI for '{key}': {roi}") + + return selected_rois + + +def get_image_from_lerobot_dataset(dataset: LeRobotDataset): + """ + Find the first row in the dataset and extract the image in order to be used for the crop. + """ + row = dataset[0] + image_dict = {} + for k in row: + if "image" in k: + image_dict[k] = deepcopy(row[k]) + return image_dict + + +def convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset: LeRobotDataset, + crop_params_dict: Dict[str, Tuple[int, int, int, int]], + new_repo_id: str, + new_dataset_root: str, + resize_size: Tuple[int, int] = (128, 128), + push_to_hub: bool = False, + task: str = "", +) -> LeRobotDataset: + """ + Converts an existing LeRobotDataset by iterating over its episodes and frames, + applying cropping and resizing to image observations, and saving a new dataset + with the transformed data. + + Args: + original_dataset (LeRobotDataset): The source dataset. + crop_params_dict (Dict[str, Tuple[int, int, int, int]]): + A dictionary mapping observation keys to crop parameters (top, left, height, width). + new_repo_id (str): Repository id for the new dataset. + new_dataset_root (str): The root directory where the new dataset will be written. + resize_size (Tuple[int, int], optional): The target size (height, width) after cropping. + Defaults to (128, 128). + + Returns: + LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped + and resized. + """ + # 1. Create a new (empty) LeRobotDataset for writing. + new_dataset = LeRobotDataset.create( + repo_id=new_repo_id, + fps=original_dataset.fps, + root=new_dataset_root, + robot_type=original_dataset.meta.robot_type, + features=original_dataset.meta.info["features"], + use_videos=len(original_dataset.meta.video_keys) > 0, + ) + + # Update the metadata for every image key that will be cropped: + # (Here we simply set the shape to be the final resize_size.) + for key in crop_params_dict: + if key in new_dataset.meta.info["features"]: + new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size) + + # TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset + prev_episode_index = 0 + for frame_idx in tqdm(range(len(original_dataset))): + frame = original_dataset[frame_idx] + + # Create a copy of the frame to add to the new dataset + new_frame = {} + for key, value in frame.items(): + if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"): + continue + if key in ("next.done", "next.reward"): + # if not isinstance(value, str) and len(value.shape) == 0: + value = value.unsqueeze(0) + + if key in crop_params_dict: + top, left, height, width = crop_params_dict[key] + # Apply crop then resize. + cropped = F.crop(value, top, left, height, width) + value = F.resize(cropped, resize_size) + value = value.clamp(0, 1) + + new_frame[key] = value + + new_dataset.add_frame(new_frame, task=task) + + if frame["episode_index"].item() != prev_episode_index: + # Save the episode + new_dataset.save_episode() + prev_episode_index = frame["episode_index"].item() + + # Save the last episode + new_dataset.save_episode() + + if push_to_hub: + new_dataset.push_to_hub() + + return new_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.") + parser.add_argument( + "--repo-id", + type=str, + default="lerobot", + help="The repository id of the LeRobot dataset to process.", + ) + parser.add_argument( + "--root", + type=str, + default=None, + help="The root directory of the LeRobot dataset.", + ) + parser.add_argument( + "--crop-params-path", + type=str, + default=None, + help="The path to the JSON file containing the ROIs.", + ) + parser.add_argument( + "--push-to-hub", + type=bool, + default=False, + help="Whether to push the new dataset to the hub.", + ) + parser.add_argument( + "--task", + type=str, + default="", + help="The natural language task to describe the dataset.", + ) + args = parser.parse_args() + + dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root) + + images = get_image_from_lerobot_dataset(dataset) + images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()} + images = {k: (v * 255).astype("uint8") for k, v in images.items()} + + if args.crop_params_path is None: + rois = select_square_roi_for_images(images) + else: + with open(args.crop_params_path) as f: + rois = json.load(f) + + # Print the selected rectangular ROIs + print("\nSelected Rectangular Regions of Interest (top, left, height, width):") + for key, roi in rois.items(): + print(f"{key}: {roi}") + + new_repo_id = args.repo_id + "_cropped_resized" + new_dataset_root = Path(str(dataset.root) + "_cropped_resized") + + cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset=dataset, + crop_params_dict=rois, + new_repo_id=new_repo_id, + new_dataset_root=new_dataset_root, + resize_size=(128, 128), + push_to_hub=args.push_to_hub, + task=args.task, + ) + + meta_dir = new_dataset_root / "meta" + meta_dir.mkdir(exist_ok=True) + + with open(meta_dir / "crop_params.json", "w") as f: + json.dump(rois, f, indent=4) diff --git a/lerobot/scripts/rl/eval_policy.py b/lerobot/scripts/rl/eval_policy.py new file mode 100644 index 000000000..3762719bf --- /dev/null +++ b/lerobot/scripts/rl/eval_policy.py @@ -0,0 +1,74 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( + gamepad, # noqa: F401 + so101_leader, # noqa: F401 +) +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl.gym_manipulator import make_robot_env + +logging.basicConfig(level=logging.INFO) + + +def eval_policy(env, policy, n_episodes): + sum_reward_episode = [] + for _ in range(n_episodes): + obs, _ = env.reset() + episode_reward = 0.0 + while True: + action = policy.select_action(obs) + obs, reward, terminated, truncated, _ = env.step(action) + episode_reward += reward + if terminated or truncated: + break + sum_reward_episode.append(episode_reward) + + logging.info(f"Success after 20 steps {sum_reward_episode}") + logging.info(f"success rate {sum(sum_reward_episode) / len(sum_reward_episode)}") + + +@parser.wrap() +def main(cfg: TrainRLServerPipelineConfig): + env_cfg = cfg.env + env = make_robot_env(env_cfg) + dataset_cfg = cfg.dataset + dataset = LeRobotDataset(repo_id=dataset_cfg.repo_id) + dataset_meta = dataset.meta + + policy = make_policy( + cfg=cfg.policy, + # env_cfg=cfg.env, + ds_meta=dataset_meta, + ) + policy.from_pretrained(env_cfg.pretrained_policy_name_or_path) + policy.eval() + + eval_policy(env, policy=policy, n_episodes=10) + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/rl/gym_manipulator.py b/lerobot/scripts/rl/gym_manipulator.py new file mode 100644 index 000000000..e7327d96d --- /dev/null +++ b/lerobot/scripts/rl/gym_manipulator.py @@ -0,0 +1,2266 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Robot Environment for LeRobot Manipulation Tasks + +This module provides a comprehensive gym-compatible environment for robot manipulation +with support for: +- Multiple robot types (SO100, SO101, Koch and Moss) +- Human intervention via leader-follower control or gamepad + +- End-effector and joint space control +- Image processing (cropping and resizing) + +The environment is built using a composable wrapper pattern where each wrapper +adds specific functionality to the base RobotEnv. + +Example: + env = make_robot_env(cfg) + obs, info = env.reset() + action = policy.select_action(obs) + obs, reward, terminated, truncated, info = env.step(action) +""" + +import logging +import time +from collections import deque +from threading import Lock +from typing import Annotated, Any, Sequence + +import gymnasium as gym +import numpy as np +import torch +import torchvision.transforms.functional as F # noqa: N812 + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.envs.configs import EnvConfig +from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.model.kinematics import RobotKinematics +from lerobot.common.robots import ( # noqa: F401 + RobotConfig, + make_robot_from_config, + so100_follower, +) +from lerobot.common.teleoperators import ( + gamepad, # noqa: F401 + keyboard, # noqa: F401 + make_teleoperator_from_config, + so101_leader, # noqa: F401 +) +from lerobot.common.teleoperators.gamepad.teleop_gamepad import GamepadTeleop +from lerobot.common.teleoperators.keyboard.teleop_keyboard import KeyboardEndEffectorTeleop +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import log_say +from lerobot.configs import parser + +logging.basicConfig(level=logging.INFO) + + +def reset_follower_position(robot_arm, target_position): + current_position_dict = robot_arm.bus.sync_read("Present_Position") + current_position = np.array( + [current_position_dict[name] for name in current_position_dict], dtype=np.float32 + ) + trajectory = torch.from_numpy( + np.linspace(current_position, target_position, 50) + ) # NOTE: 30 is just an arbitrary number + for pose in trajectory: + action_dict = dict(zip(current_position_dict, pose, strict=False)) + robot_arm.bus.sync_write("Goal_Position", action_dict) + busy_wait(0.015) + + +class TorchBox(gym.spaces.Box): + """ + A version of gym.spaces.Box that handles PyTorch tensors. + + This class extends gym.spaces.Box to work with PyTorch tensors, + providing compatibility between NumPy arrays and PyTorch tensors. + """ + + def __init__( + self, + low: float | Sequence[float] | np.ndarray, + high: float | Sequence[float] | np.ndarray, + shape: Sequence[int] | None = None, + np_dtype: np.dtype | type = np.float32, + torch_dtype: torch.dtype = torch.float32, + device: str = "cpu", + seed: int | np.random.Generator | None = None, + ) -> None: + """ + Initialize the PyTorch-compatible Box space. + + Args: + low: Lower bounds of the space. + high: Upper bounds of the space. + shape: Shape of the space. If None, inferred from low and high. + np_dtype: NumPy data type for internal storage. + torch_dtype: PyTorch data type for tensor conversion. + device: PyTorch device for returned tensors. + seed: Random seed for sampling. + """ + super().__init__(low, high, shape=shape, dtype=np_dtype, seed=seed) + self.torch_dtype = torch_dtype + self.device = device + + def sample(self) -> torch.Tensor: + """ + Sample a random point from the space. + + Returns: + A PyTorch tensor within the space bounds. + """ + arr = super().sample() + return torch.as_tensor(arr, dtype=self.torch_dtype, device=self.device) + + def contains(self, x: torch.Tensor) -> bool: + """ + Check if a tensor is within the space bounds. + + Args: + x: The PyTorch tensor to check. + + Returns: + Boolean indicating whether the tensor is within bounds. + """ + # Move to CPU/numpy and cast to the internal dtype + arr = x.detach().cpu().numpy().astype(self.dtype, copy=False) + return super().contains(arr) + + def seed(self, seed: int | np.random.Generator | None = None): + """ + Set the random seed for sampling. + + Args: + seed: The random seed to use. + + Returns: + List containing the seed. + """ + super().seed(seed) + return [seed] + + def __repr__(self) -> str: + """ + Return a string representation of the space. + + Returns: + Formatted string with space details. + """ + return ( + f"TorchBox({self.low_repr}, {self.high_repr}, {self.shape}, " + f"np={self.dtype.name}, torch={self.torch_dtype}, device={self.device})" + ) + + +class TorchActionWrapper(gym.Wrapper): + """ + Wrapper that changes the action space to use PyTorch tensors. + + This wrapper modifies the action space to return PyTorch tensors when sampled + and handles converting PyTorch actions to NumPy when stepping the environment. + """ + + def __init__(self, env: gym.Env, device: str): + """ + Initialize the PyTorch action space wrapper. + + Args: + env: The environment to wrap. + device: The PyTorch device to use for tensor operations. + """ + super().__init__(env) + self.action_space = TorchBox( + low=env.action_space.low, + high=env.action_space.high, + shape=env.action_space.shape, + torch_dtype=torch.float32, + device=torch.device("cpu"), + ) + + def step(self, action: torch.Tensor): + """ + Step the environment with a PyTorch tensor action. + + This method handles conversion from PyTorch tensors to NumPy arrays + for compatibility with the underlying environment. + + Args: + action: PyTorch tensor action to take. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + if action.dim() == 2: + action = action.squeeze(0) + action = action.detach().cpu().numpy() + return self.env.step(action) + + +class RobotEnv(gym.Env): + """ + Gym-compatible environment for evaluating robotic control policies with integrated human intervention. + + This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta) + and absolute joint position commands and automatically configures its observation and action spaces based on the robot's + sensors and configuration. + """ + + def __init__( + self, + robot, + use_gripper: bool = False, + display_cameras: bool = False, + ): + """ + Initialize the RobotEnv environment. + + The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup + supports both relative (delta) adjustments and absolute joint positions for controlling the robot. + + Args: + robot: The robot interface object used to connect and interact with the physical robot. + display_cameras: If True, the robot's camera feeds will be displayed during execution. + """ + super().__init__() + + self.robot = robot + self.display_cameras = display_cameras + + # Connect to the robot if not already connected. + if not self.robot.is_connected: + self.robot.connect() + + # Episode tracking. + self.current_step = 0 + self.episode_data = None + + self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors] + self._image_keys = self.robot.cameras.keys() + + # Read initial joint positions using the bus + self.current_joint_positions = self._get_observation()["agent_pos"] + + self.use_gripper = use_gripper + + self._setup_spaces() + + def _get_observation(self) -> np.ndarray: + """Helper to convert a dictionary from bus.sync_read to an ordered numpy array.""" + obs_dict = self.robot.get_observation() + joint_positions = np.array([obs_dict[name] for name in self._joint_names], dtype=np.float32) + + images = {key: obs_dict[key] for key in self._image_keys} + return {"agent_pos": joint_positions, "pixels": images} + + def _setup_spaces(self): + """ + Dynamically configure the observation and action spaces based on the robot's capabilities. + + Observation Space: + - For keys with "image": A Box space with pixel values ranging from 0 to 255. + - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. + + Action Space: + - The action space is defined as a Box space representing joint position commands. It is defined as relative (delta) + or absolute, based on the configuration. + """ + example_obs = self._get_observation() + + observation_spaces = {} + + # Define observation spaces for images and other states. + if "pixels" in example_obs: + prefix = "observation.images" if len(example_obs["pixels"]) > 1 else "observation.image" + observation_spaces = { + f"{prefix}.{key}": gym.spaces.Box( + low=0, high=255, shape=example_obs["pixels"][key].shape, dtype=np.uint8 + ) + for key in example_obs["pixels"] + } + + observation_spaces["observation.state"] = gym.spaces.Box( + low=0, + high=10, + shape=example_obs["agent_pos"].shape, + dtype=np.float32, + ) + + self.observation_space = gym.spaces.Dict(observation_spaces) + + # Define the action space for joint positions along with setting an intervention flag. + action_dim = 3 + bounds = {} + bounds["min"] = -np.ones(action_dim) + bounds["max"] = np.ones(action_dim) + + if self.use_gripper: + action_dim += 1 + bounds["min"] = np.concatenate([bounds["min"], [0]]) + bounds["max"] = np.concatenate([bounds["max"], [2]]) + + self.action_space = gym.spaces.Box( + low=bounds["min"], + high=bounds["max"], + shape=(action_dim,), + dtype=np.float32, + ) + + def reset(self, seed=None, options=None) -> tuple[dict[str, np.ndarray], dict[str, Any]]: + """ + Reset the environment to its initial state. + This method resets the step counter and clears any episodic data. + + Args: + seed: A seed for random number generation to ensure reproducibility. + options: Additional options to influence the reset behavior. + + Returns: + A tuple containing: + - observation (dict): The initial sensor observation. + - info (dict): A dictionary with supplementary information, including the key "is_intervention". + """ + super().reset(seed=seed, options=options) + + self.robot.reset() + + # Capture the initial observation. + observation = self._get_observation() + + # Reset episode tracking variables. + self.current_step = 0 + self.episode_data = None + + return observation, {"is_intervention": False} + + def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]: + """ + Execute a single step within the environment using the specified action. + + The provided action is processed and sent to the robot as joint position commands + that may be either absolute values or deltas based on the environment configuration. + + Args: + action: The commanded joint positions as a numpy array or torch tensor. + + Returns: + A tuple containing: + - observation (dict): The new sensor observation after taking the step. + - reward (float): The step reward (default is 0.0 within this wrapper). + - terminated (bool): True if the episode has reached a terminal state. + - truncated (bool): True if the episode was truncated (e.g., time constraints). + - info (dict): Additional debugging information including intervention status. + """ + self.current_joint_positions = self._get_observation()["agent_pos"] + + action_dict = {"delta_x": action[0], "delta_y": action[1], "delta_z": action[2]} + + # 1.0 action corresponds to no-op action + action_dict["gripper"] = action[3] if self.use_gripper else 1.0 + + self.robot.send_action(action_dict) + + if self.display_cameras: + self.render() + + self.current_step += 1 + + reward = 0.0 + terminated = False + truncated = False + + return ( + self._get_observation(), + reward, + terminated, + truncated, + {"is_intervention": False}, + ) + + def render(self): + """ + Render the current state of the environment by displaying the robot's camera feeds. + """ + import cv2 + + observation = self._get_observation() + image_keys = [key for key in observation if "image" in key] + + for key in image_keys: + cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.waitKey(1) + + def close(self): + """ + Close the environment and clean up resources by disconnecting the robot. + + If the robot is currently connected, this method properly terminates the connection to ensure that all + associated resources are released. + """ + if self.robot.is_connected: + self.robot.disconnect() + + +class AddJointVelocityToObservation(gym.ObservationWrapper): + """ + Wrapper that adds joint velocity information to the observation. + + This wrapper computes joint velocities by tracking changes in joint positions over time, + and extends the observation space to include these velocities. + """ + + def __init__(self, env, joint_velocity_limits=100.0, fps=30, num_dof=6): + """ + Initialize the joint velocity wrapper. + + Args: + env: The environment to wrap. + joint_velocity_limits: Maximum expected joint velocity for space bounds. + fps: Frames per second used to calculate velocity (position delta / time). + num_dof: Number of degrees of freedom (joints) in the robot. + """ + super().__init__(env) + + # Extend observation space to include joint velocities + old_low = self.observation_space["observation.state"].low + old_high = self.observation_space["observation.state"].high + old_shape = self.observation_space["observation.state"].shape + + self.last_joint_positions = np.zeros(num_dof) + + new_low = np.concatenate([old_low, np.ones(num_dof) * -joint_velocity_limits]) + new_high = np.concatenate([old_high, np.ones(num_dof) * joint_velocity_limits]) + + new_shape = (old_shape[0] + num_dof,) + + self.observation_space["observation.state"] = gym.spaces.Box( + low=new_low, + high=new_high, + shape=new_shape, + dtype=np.float32, + ) + + self.dt = 1.0 / fps + + def observation(self, observation): + """ + Add joint velocity information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with joint velocities. + """ + joint_velocities = (observation["agent_pos"] - self.last_joint_positions) / self.dt + self.last_joint_positions = observation["agent_pos"] + observation["agent_pos"] = np.concatenate([observation["agent_pos"], joint_velocities], axis=-1) + return observation + + +class AddCurrentToObservation(gym.ObservationWrapper): + """ + Wrapper that adds motor current information to the observation. + + This wrapper extends the observation space to include the current values + from each motor, providing information about the forces being applied. + """ + + def __init__(self, env, max_current=500, num_dof=6): + """ + Initialize the current observation wrapper. + + Args: + env: The environment to wrap. + max_current: Maximum expected current for space bounds. + num_dof: Number of degrees of freedom (joints) in the robot. + """ + super().__init__(env) + + # Extend observation space to include joint velocities + old_low = self.observation_space["observation.state"].low + old_high = self.observation_space["observation.state"].high + old_shape = self.observation_space["observation.state"].shape + + new_low = np.concatenate([old_low, np.zeros(num_dof)]) + new_high = np.concatenate([old_high, np.ones(num_dof) * max_current]) + + new_shape = (old_shape[0] + num_dof,) + + self.observation_space["observation.state"] = gym.spaces.Box( + low=new_low, + high=new_high, + shape=new_shape, + dtype=np.float32, + ) + + def observation(self, observation): + """ + Add current information to the observation. + + Args: + observation: The original observation from the environment. + + Returns: + The modified observation with current values. + """ + present_current_observation = self.unwrapped._get_observation()["agent_pos"] + observation["agent_pos"] = np.concatenate( + [observation["agent_pos"], present_current_observation], axis=-1 + ) + return observation + + +class RewardWrapper(gym.Wrapper): + def __init__(self, env, reward_classifier, device="cuda"): + """ + Wrapper to add reward prediction to the environment using a trained classifier. + + Args: + env: The environment to wrap. + reward_classifier: The reward classifier model. + device: The device to run the model on. + """ + self.env = env + + self.device = device + + self.reward_classifier = torch.compile(reward_classifier) + self.reward_classifier.to(self.device) + + def step(self, action): + """ + Execute a step and compute the reward using the classifier. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + observation, _, terminated, truncated, info = self.env.step(action) + + images = {} + for key in observation: + if "image" in key: + images[key] = observation[key].to(self.device, non_blocking=(self.device == "cuda")) + if images[key].dim() == 3: + images[key] = images[key].unsqueeze(0) + + start_time = time.perf_counter() + with torch.inference_mode(): + success = ( + self.reward_classifier.predict_reward(images, threshold=0.7) + if self.reward_classifier is not None + else 0.0 + ) + info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time) + + reward = 0.0 + if success == 1.0: + terminated = True + reward = 1.0 + + return observation, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + return self.env.reset(seed=seed, options=options) + + +class TimeLimitWrapper(gym.Wrapper): + """ + Wrapper that adds a time limit to episodes and tracks execution time. + + This wrapper terminates episodes after a specified time has elapsed, providing + better control over episode length. + """ + + def __init__(self, env, control_time_s, fps): + """ + Initialize the time limit wrapper. + + Args: + env: The environment to wrap. + control_time_s: Maximum episode duration in seconds. + fps: Frames per second for calculating the maximum number of steps. + """ + self.env = env + self.control_time_s = control_time_s + self.fps = fps + + self.last_timestamp = 0.0 + self.episode_time_in_s = 0.0 + + self.max_episode_steps = int(self.control_time_s * self.fps) + + self.current_step = 0 + + def step(self, action): + """ + Step the environment and track time elapsed. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + obs, reward, terminated, truncated, info = self.env.step(action) + time_since_last_step = time.perf_counter() - self.last_timestamp + self.episode_time_in_s += time_since_last_step + self.last_timestamp = time.perf_counter() + self.current_step += 1 + # check if last timestep took more time than the expected fps + if 1.0 / time_since_last_step < self.fps: + logging.debug(f"Current timestep exceeded expected fps {self.fps}") + + if self.current_step >= self.max_episode_steps: + terminated = True + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment and time tracking. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + self.episode_time_in_s = 0.0 + self.last_timestamp = time.perf_counter() + self.current_step = 0 + return self.env.reset(seed=seed, options=options) + + +class ImageCropResizeWrapper(gym.Wrapper): + """ + Wrapper that crops and resizes image observations. + + This wrapper processes image observations to focus on relevant regions by + cropping and then resizing to a standard size. + """ + + def __init__( + self, + env, + crop_params_dict: dict[str, Annotated[tuple[int], 4]], + resize_size=None, + ): + """ + Initialize the image crop and resize wrapper. + + Args: + env: The environment to wrap. + crop_params_dict: Dictionary mapping image observation keys to crop parameters + (top, left, height, width). + resize_size: Target size for resized images (height, width). Defaults to (128, 128). + """ + super().__init__(env) + self.env = env + self.crop_params_dict = crop_params_dict + print(f"obs_keys , {self.env.observation_space}") + print(f"crop params dict {crop_params_dict.keys()}") + for key_crop in crop_params_dict: + if key_crop not in self.env.observation_space.keys(): # noqa: SIM118 + raise ValueError(f"Key {key_crop} not in observation space") + for key in crop_params_dict: + new_shape = (3, resize_size[0], resize_size[1]) + self.observation_space[key] = gym.spaces.Box(low=0, high=255, shape=new_shape) + + self.resize_size = resize_size + if self.resize_size is None: + self.resize_size = (128, 128) + + def step(self, action): + """ + Step the environment and process image observations. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with processed images. + """ + obs, reward, terminated, truncated, info = self.env.step(action) + for k in self.crop_params_dict: + device = obs[k].device + if obs[k].dim() >= 3: + # Reshape to combine height and width dimensions for easier calculation + batch_size = obs[k].size(0) + channels = obs[k].size(1) + flattened_spatial_dims = obs[k].view(batch_size, channels, -1) + + # Calculate standard deviation across spatial dimensions (H, W) + # If any channel has std=0, all pixels in that channel have the same value + # This is helpful if one camera mistakenly covered or the image is black + std_per_channel = torch.std(flattened_spatial_dims, dim=2) + if (std_per_channel <= 0.02).any(): + logging.warning( + f"Potential hardware issue detected: All pixels have the same value in observation {k}" + ) + + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + # TODO (michel-aractingi): Bug in resize, it returns values outside [0, 1] + obs[k] = obs[k].clamp(0.0, 1.0) + obs[k] = obs[k].to(device) + + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + """ + Reset the environment and process image observations. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + Tuple of (observation, info) with processed images. + """ + obs, info = self.env.reset(seed=seed, options=options) + for k in self.crop_params_dict: + device = obs[k].device + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + obs[k] = obs[k].clamp(0.0, 1.0) + obs[k] = obs[k].to(device) + return obs, info + + +class ConvertToLeRobotObservation(gym.ObservationWrapper): + """ + Wrapper that converts standard observations to LeRobot format. + + This wrapper processes observations to match the expected format for LeRobot, + including normalizing image values and moving tensors to the specified device. + """ + + def __init__(self, env, device: str = "cpu"): + """ + Initialize the LeRobot observation converter. + + Args: + env: The environment to wrap. + device: Target device for the observation tensors. + """ + super().__init__(env) + + self.device = torch.device(device) + + def observation(self, observation): + """ + Convert observations to LeRobot format. + + Args: + observation: The original observation from the environment. + + Returns: + The processed observation with normalized images and proper tensor formats. + """ + observation = preprocess_observation(observation) + observation = { + key: observation[key].to(self.device, non_blocking=self.device.type == "cuda") + for key in observation + } + return observation + + +class ResetWrapper(gym.Wrapper): + """ + Wrapper that handles environment reset procedures. + + This wrapper provides additional functionality during environment reset, + including the option to reset to a fixed pose or allow manual reset. + """ + + def __init__( + self, + env: RobotEnv, + reset_pose: np.ndarray | None = None, + reset_time_s: float = 5, + ): + """ + Initialize the reset wrapper. + + Args: + env: The environment to wrap. + reset_pose: Fixed joint positions to reset to. If None, manual reset is used. + reset_time_s: Time in seconds to wait after reset or allowed for manual reset. + """ + super().__init__(env) + self.reset_time_s = reset_time_s + self.reset_pose = reset_pose + self.robot = self.unwrapped.robot + + def reset(self, *, seed=None, options=None): + """ + Reset the environment with either fixed or manual reset procedure. + + If reset_pose is provided, the robot will move to that position. + Otherwise, manual teleoperation control is allowed for reset_time_s seconds. + + Args: + seed: Random seed for reproducibility. + options: Additional reset options. + + Returns: + The initial observation and info from the wrapped environment. + """ + start_time = time.perf_counter() + if self.reset_pose is not None: + log_say("Reset the environment.", play_sounds=True) + reset_follower_position(self.unwrapped.robot, self.reset_pose) + log_say("Reset the environment done.", play_sounds=True) + + if hasattr(self.env, "robot_leader"): + self.env.robot_leader.bus.sync_write("Torque_Enable", 1) + log_say("Reset the leader robot.", play_sounds=True) + reset_follower_position(self.env.robot_leader, self.reset_pose) + log_say("Reset the leader robot done.", play_sounds=True) + else: + log_say( + f"Manually reset the environment for {self.reset_time_s} seconds.", + play_sounds=True, + ) + start_time = time.perf_counter() + while time.perf_counter() - start_time < self.reset_time_s: + action = self.env.robot_leader.get_action() + self.unwrapped.robot.send_action(action) + + log_say("Manual reset of the environment done.", play_sounds=True) + + busy_wait(self.reset_time_s - (time.perf_counter() - start_time)) + + return super().reset(seed=seed, options=options) + + +class BatchCompatibleWrapper(gym.ObservationWrapper): + """ + Wrapper that ensures observations are compatible with batch processing. + + This wrapper adds a batch dimension to observations that don't already have one, + making them compatible with models that expect batched inputs. + """ + + def __init__(self, env): + """ + Initialize the batch compatibility wrapper. + + Args: + env: The environment to wrap. + """ + super().__init__(env) + + def observation(self, observation: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Add batch dimensions to observations if needed. + + Args: + observation: Dictionary of observation tensors. + + Returns: + Dictionary of observation tensors with batch dimensions. + """ + for key in observation: + if "image" in key and observation[key].dim() == 3: + observation[key] = observation[key].unsqueeze(0) + if "state" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + if "velocity" in key and observation[key].dim() == 1: + observation[key] = observation[key].unsqueeze(0) + return observation + + +class GripperPenaltyWrapper(gym.RewardWrapper): + """ + Wrapper that adds penalties for inefficient gripper commands. + + This wrapper modifies rewards to discourage excessive gripper movement + or commands that attempt to move the gripper beyond its physical limits. + """ + + def __init__(self, env, penalty: float = -0.1): + """ + Initialize the gripper penalty wrapper. + + Args: + env: The environment to wrap. + penalty: Negative reward value to apply for inefficient gripper actions. + """ + super().__init__(env) + self.penalty = penalty + self.last_gripper_state = None + + def reward(self, reward, action): + """ + Apply penalties to reward based on gripper actions. + + Args: + reward: The original reward from the environment. + action: The action that was taken. + + Returns: + Modified reward with penalty applied if necessary. + """ + gripper_state_normalized = self.last_gripper_state / self.unwrapped.robot.config.max_gripper_pos + + action_normalized = action - 1.0 # action / MAX_GRIPPER_COMMAND + + gripper_penalty_bool = (gripper_state_normalized < 0.5 and action_normalized > 0.5) or ( + gripper_state_normalized > 0.75 and action_normalized < -0.5 + ) + + return reward + self.penalty * int(gripper_penalty_bool) + + def step(self, action): + """ + Step the environment and apply gripper penalties. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info) with penalty applied. + """ + self.last_gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] + + gripper_action = action[-1] + obs, reward, terminated, truncated, info = self.env.step(action) + gripper_penalty = self.reward(reward, gripper_action) + + info["discrete_penalty"] = gripper_penalty + + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """ + Reset the environment and penalty tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info with gripper penalty initialized. + """ + self.last_gripper_state = None + obs, info = super().reset(**kwargs) + info["gripper_penalty"] = 0.0 + return obs, info + + +class GripperActionWrapper(gym.ActionWrapper): + """ + Wrapper that processes gripper control commands. + + This wrapper quantizes and processes gripper commands, adding a sleep time between + consecutive gripper actions to prevent rapid toggling. + """ + + def __init__(self, env, quantization_threshold: float = 0.2, gripper_sleep: float = 0.0): + """ + Initialize the gripper action wrapper. + + Args: + env: The environment to wrap. + quantization_threshold: Threshold below which gripper commands are quantized to zero. + gripper_sleep: Minimum time in seconds between consecutive gripper commands. + """ + super().__init__(env) + self.quantization_threshold = quantization_threshold + self.gripper_sleep = gripper_sleep + self.last_gripper_action_time = 0.0 + self.last_gripper_action = None + + def action(self, action): + """ + Process gripper commands in the action. + + Args: + action: The original action from the agent. + + Returns: + Modified action with processed gripper command. + """ + if self.gripper_sleep > 0.0: + if ( + self.last_gripper_action is not None + and time.perf_counter() - self.last_gripper_action_time < self.gripper_sleep + ): + action[-1] = self.last_gripper_action + else: + self.last_gripper_action_time = time.perf_counter() + self.last_gripper_action = action[-1] + + gripper_command = action[-1] + # Gripper actions are between 0, 2 + # we want to quantize them to -1, 0 or 1 + gripper_command = gripper_command - 1.0 + + if self.quantization_threshold is not None: + # Quantize gripper command to -1, 0 or 1 + gripper_command = ( + np.sign(gripper_command) if abs(gripper_command) > self.quantization_threshold else 0.0 + ) + gripper_command = gripper_command * self.unwrapped.robot.config.max_gripper_pos + + gripper_state = self.unwrapped.robot.bus.sync_read("Present_Position")["gripper"] + + gripper_action_value = np.clip( + gripper_state + gripper_command, 0, self.unwrapped.robot.config.max_gripper_pos + ) + action[-1] = gripper_action_value.item() + return action + + def reset(self, **kwargs): + """ + Reset the gripper action tracking. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + obs, info = super().reset(**kwargs) + self.last_gripper_action_time = 0.0 + self.last_gripper_action = None + return obs, info + + +class EEObservationWrapper(gym.ObservationWrapper): + """ + Wrapper that adds end-effector pose information to observations. + + This wrapper computes the end-effector pose using forward kinematics + and adds it to the observation space. + """ + + def __init__(self, env, ee_pose_limits): + """ + Initialize the end-effector observation wrapper. + + Args: + env: The environment to wrap. + ee_pose_limits: Dictionary with 'min' and 'max' keys containing limits for EE pose. + """ + super().__init__(env) + + # Extend observation space to include end effector pose + prev_space = self.observation_space["observation.state"] + + self.observation_space["observation.state"] = gym.spaces.Box( + low=np.concatenate([prev_space.low, ee_pose_limits["min"]]), + high=np.concatenate([prev_space.high, ee_pose_limits["max"]]), + shape=(prev_space.shape[0] + 3,), + dtype=np.float32, + ) + + # Initialize kinematics instance for the appropriate robot type + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + self.kinematics = RobotKinematics(robot_type) + + def observation(self, observation): + """ + Add end-effector pose to the observation. + + Args: + observation: Original observation from the environment. + + Returns: + Enhanced observation with end-effector pose information. + """ + current_joint_pos = self.unwrapped._get_observation()["agent_pos"] + + current_ee_pos = self.kinematics.forward_kinematics(current_joint_pos, frame="gripper_tip")[:3, 3] + observation["agent_pos"] = np.concatenate([observation["agent_pos"], current_ee_pos], -1) + return observation + + +########################################################### +# Wrappers related to human intervention and input devices +########################################################### + + +class BaseLeaderControlWrapper(gym.Wrapper): + """ + Base class for leader-follower robot control wrappers. + + This wrapper enables human intervention through a leader-follower robot setup, + where the human can control a leader robot to guide the follower robot's movements. + """ + + def __init__( + self, + env, + teleop_device, + end_effector_step_sizes, + use_geared_leader_arm: bool = False, + use_gripper=False, + ): + """ + Initialize the base leader control wrapper. + + Args: + env: The environment to wrap. + teleop_device: The teleoperation device. + use_geared_leader_arm: Whether to use a geared leader arm setup. + use_gripper: Whether to include gripper control. + """ + super().__init__(env) + self.robot_leader = teleop_device + self.robot_follower = env.unwrapped.robot + self.use_geared_leader_arm = use_geared_leader_arm + self.use_gripper: bool = use_gripper + self.end_effector_step_sizes = np.array(list(end_effector_step_sizes.values())) + + # Set up keyboard event tracking + self._init_keyboard_events() + self.event_lock = Lock() # Thread-safe access to events + + # Initialize robot control + robot_type = getattr(env.unwrapped.robot.config, "robot_type", "so101") + if "so100" in robot_type or "so101" in robot_type: + # Note to be compatible with the rest of the codebase, + # we are using the new calibration method for so101 and so100 + robot_type = "so_new_calibration" + self.kinematics = RobotKinematics(robot_type) + self.leader_torque_enabled = True + self.prev_leader_gripper = None + + # Configure leader arm + # NOTE: Lower the gains of leader arm for automatic take-over + # With lower gains we can manually move the leader arm without risk of injury to ourselves or the robot + # With higher gains, it would be dangerous and difficult to modify the leader's pose while torque is enabled + # Default value for P_coeff is 32 + self.robot_leader.bus.sync_write("Torque_Enable", 1) + for motor in self.robot_leader.bus.motors: + self.robot_leader.bus.write("P_Coefficient", motor, 16) + self.robot_leader.bus.write("I_Coefficient", motor, 0) + self.robot_leader.bus.write("D_Coefficient", motor, 16) + + self.leader_tracking_error_queue = deque(maxlen=4) + self._init_keyboard_listener() + + def _init_keyboard_events(self): + """ + Initialize the keyboard events dictionary. + + This method sets up tracking for keyboard events used for intervention control. + It should be overridden in subclasses to add additional events. + """ + self.keyboard_events = { + "episode_success": False, + "episode_end": False, + "rerecord_episode": False, + } + + def _handle_key_press(self, key, keyboard_device): + """ + Handle key press events. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + This method should be overridden in subclasses for additional key handling. + """ + try: + if key == keyboard_device.Key.esc: + self.keyboard_events["episode_end"] = True + return + if key == keyboard_device.Key.left: + self.keyboard_events["rerecord_episode"] = True + return + if hasattr(key, "char") and key.char == "s": + logging.info("Key 's' pressed. Episode success triggered.") + self.keyboard_events["episode_success"] = True + return + except Exception as e: + logging.error(f"Error handling key press: {e}") + + def _init_keyboard_listener(self): + """ + Initialize the keyboard listener for intervention control. + + This method sets up keyboard event handling if not in headless mode. + """ + from pynput import keyboard as keyboard_device + + def on_press(key): + with self.event_lock: + self._handle_key_press(key, keyboard_device) + + self.listener = keyboard_device.Listener(on_press=on_press) + self.listener.start() + + def _check_intervention(self): + """ + Check if human intervention is needed. + + Returns: + Boolean indicating whether intervention is needed. + + This method should be overridden in subclasses with specific intervention logic. + """ + return False + + def _handle_intervention(self, action): + """ + Process actions during intervention mode. + + Args: + action: The original action from the agent. + + Returns: + Tuple of (modified_action, intervention_action). + """ + if self.leader_torque_enabled: + self.robot_leader.bus.sync_write("Torque_Enable", 0) + self.leader_torque_enabled = False + + leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") + follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") + + leader_pos = np.array([leader_pos_dict[name] for name in leader_pos_dict], dtype=np.float32) + follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) + + self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - leader_pos[:-1])) + + # [:3, 3] Last column of the transformation matrix corresponds to the xyz translation + leader_ee = self.kinematics.forward_kinematics(leader_pos, frame="gripper_tip")[:3, 3] + follower_ee = self.kinematics.forward_kinematics(follower_pos, frame="gripper_tip")[:3, 3] + + action = np.clip(leader_ee - follower_ee, -self.end_effector_step_sizes, self.end_effector_step_sizes) + # Normalize the action to the range [-1, 1] + action = action / self.end_effector_step_sizes + + if self.use_gripper: + if self.prev_leader_gripper is None: + self.prev_leader_gripper = np.clip( + leader_pos[-1], 0, self.robot_follower.config.max_gripper_pos + ) + + # Get gripper action delta based on leader pose + leader_gripper = leader_pos[-1] + gripper_delta = leader_gripper - self.prev_leader_gripper + + # Normalize by max angle and quantize to {0,1,2} + normalized_delta = gripper_delta / self.robot_follower.config.max_gripper_pos + if normalized_delta >= 0.3: + gripper_action = 2 + elif normalized_delta <= 0.1: + gripper_action = 0 + else: + gripper_action = 1 + + action = np.append(action, gripper_action) + + return action + + def _handle_leader_teleoperation(self): + """ + Handle leader teleoperation in non-intervention mode. + + This method synchronizes the leader robot position with the follower. + """ + + prev_leader_pos_dict = self.robot_leader.bus.sync_read("Present_Position") + prev_leader_pos = np.array( + [prev_leader_pos_dict[name] for name in prev_leader_pos_dict], dtype=np.float32 + ) + + if not self.leader_torque_enabled: + self.robot_leader.bus.sync_write("Torque_Enable", 1) + self.leader_torque_enabled = True + + follower_pos_dict = self.robot_follower.bus.sync_read("Present_Position") + follower_pos = np.array([follower_pos_dict[name] for name in follower_pos_dict], dtype=np.float32) + + goal_pos = {f"{motor}": follower_pos[i] for i, motor in enumerate(self.robot_leader.bus.motors)} + self.robot_leader.bus.sync_write("Goal_Position", goal_pos) + + self.leader_tracking_error_queue.append(np.linalg.norm(follower_pos[:-1] - prev_leader_pos[:-1])) + + def step(self, action): + """ + Execute a step with possible human intervention. + + Args: + action: The action to take in the environment. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + is_intervention = self._check_intervention() + + # NOTE: + if is_intervention: + action = self._handle_intervention(action) + else: + self._handle_leader_teleoperation() + + # NOTE: + obs, reward, terminated, truncated, info = self.env.step(action) + + # Add intervention info + info["is_intervention"] = is_intervention + info["action_intervention"] = action + + self.prev_leader_gripper = np.clip( + self.robot_leader.bus.sync_read("Present_Position")["gripper"], + 0, + self.robot_follower.config.max_gripper_pos, + ) + + # Check for success or manual termination + success = self.keyboard_events["episode_success"] + terminated = terminated or self.keyboard_events["episode_end"] or success + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + return obs, reward, terminated, truncated, info + + def reset(self, **kwargs): + """ + Reset the environment and intervention state. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + self.keyboard_events = dict.fromkeys(self.keyboard_events, False) + self.leader_tracking_error_queue.clear() + return super().reset(**kwargs) + + def close(self): + """ + Clean up resources, including stopping keyboard listener. + + Returns: + Result of closing the wrapped environment. + """ + if hasattr(self, "listener") and self.listener is not None: + self.listener.stop() + return self.env.close() + + +class GearedLeaderControlWrapper(BaseLeaderControlWrapper): + """ + Wrapper that enables manual intervention via keyboard. + + This wrapper extends the BaseLeaderControlWrapper to allow explicit toggling + of human intervention mode with keyboard controls. + """ + + def _init_keyboard_events(self): + """ + Initialize keyboard events including human intervention flag. + + Extends the base class dictionary with an additional flag for tracking + intervention state toggled by keyboard. + """ + super()._init_keyboard_events() + self.keyboard_events["human_intervention_step"] = False + + def _handle_key_press(self, key, keyboard_device): + """ + Handle key presses including space for intervention toggle. + + Args: + key: The key that was pressed. + keyboard: The keyboard module with key definitions. + + Extends the base handler to respond to space key for toggling intervention. + """ + super()._handle_key_press(key, keyboard_device) + if key == keyboard_device.Key.space: + if not self.keyboard_events["human_intervention_step"]: + logging.info( + "Space key pressed. Human intervention required.\n" + "Place the leader in similar pose to the follower and press space again." + ) + self.keyboard_events["human_intervention_step"] = True + log_say("Human intervention step.", play_sounds=True) + else: + self.keyboard_events["human_intervention_step"] = False + logging.info("Space key pressed for a second time.\nContinuing with policy actions.") + log_say("Continuing with policy actions.", play_sounds=True) + + def _check_intervention(self): + """ + Check if human intervention is active based on keyboard toggle. + + Returns: + Boolean indicating whether intervention mode is active. + """ + return self.keyboard_events["human_intervention_step"] + + +class GearedLeaderAutomaticControlWrapper(BaseLeaderControlWrapper): + """ + Wrapper with automatic intervention based on error thresholds. + + This wrapper monitors the error between leader and follower positions + and automatically triggers intervention when error exceeds thresholds. + """ + + def __init__( + self, + env, + teleop_device, + end_effector_step_sizes, + use_gripper=False, + intervention_threshold=10.0, + release_threshold=1e-2, + ): + """ + Initialize the automatic intervention wrapper. + + Args: + env: The environment to wrap. + teleop_device: The teleoperation device. + use_gripper: Whether to include gripper control. + intervention_threshold: Error threshold to trigger intervention. + release_threshold: Error threshold to release intervention. + queue_size: Number of error measurements to track for smoothing. + """ + super().__init__(env, teleop_device, end_effector_step_sizes, use_gripper=use_gripper) + + # Error tracking parameters + self.intervention_threshold = intervention_threshold # Threshold to trigger intervention + self.release_threshold = release_threshold # Threshold to release intervention + self.is_intervention_active = False + self.start_time = time.perf_counter() + + def _check_intervention(self): + """ + Determine if intervention should occur based on the rate of change of leader-follower error in end_effector space. + + This method monitors the rate of change of leader-follower error in end_effector space + and automatically triggers intervention when the rate of change exceeds + the intervention threshold, releasing when it falls below the release threshold. + + Returns: + Boolean indicating whether intervention should be active. + """ + + # Condition for starting the intervention + # If the error in teleoperation is too high, that means the a user has grasped the leader robot and he wants to take over + if ( + not self.is_intervention_active + and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen + and np.var(list(self.leader_tracking_error_queue)[-2:]) > self.intervention_threshold + ): + self.is_intervention_active = True + self.leader_tracking_error_queue.clear() + log_say("Intervention started", play_sounds=True) + return True + + # Track the error over time in leader_tracking_error_queue + # If the variance of the tracking error is too low, that means the user has let go of the leader robot and the intervention is over + if ( + self.is_intervention_active + and len(self.leader_tracking_error_queue) == self.leader_tracking_error_queue.maxlen + and np.var(self.leader_tracking_error_queue) < self.release_threshold + ): + self.is_intervention_active = False + self.leader_tracking_error_queue.clear() + log_say("Intervention ended", play_sounds=True) + return False + + # If not change has happened that merits a change in the intervention state, return the current state + return self.is_intervention_active + + def reset(self, **kwargs): + """ + Reset error tracking on environment reset. + + Args: + **kwargs: Keyword arguments passed to the wrapped environment's reset. + + Returns: + The initial observation and info. + """ + self.is_intervention_active = False + return super().reset(**kwargs) + + +class GamepadControlWrapper(gym.Wrapper): + """ + Wrapper that allows controlling a gym environment with a gamepad. + + This wrapper intercepts the step method and allows human input via gamepad + to override the agent's actions when desired. + """ + + def __init__( + self, + env, + teleop_device, # Accepts an instantiated teleoperator + use_gripper=False, # This should align with teleop_device's config + auto_reset=False, + ): + """ + Initialize the gamepad controller wrapper. + + Args: + env: The environment to wrap. + teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). + use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). + auto_reset: Whether to auto reset the environment when episode ends. + """ + super().__init__(env) + + self.teleop_device = teleop_device + # Ensure the teleop_device is connected if it has a connect method + if hasattr(self.teleop_device, "connect") and not self.teleop_device.is_connected: + self.teleop_device.connect() + + # self.controller attribute is removed + + self.auto_reset = auto_reset + # use_gripper from args should ideally match teleop_device.config.use_gripper + # For now, we use the one passed, but it can lead to inconsistency if not set correctly from config + self.use_gripper = use_gripper + + logging.info("Gamepad control wrapper initialized with provided teleop_device.") + print( + "Gamepad controls (managed by the provided teleop_device - specific button mappings might vary):" + ) + print(" Left analog stick: Move in X-Y plane") + print(" Right analog stick: Move in Z axis (up/down)") + print(" X/Square button: End episode (FAILURE)") + print(" Y/Triangle button: End episode (SUCCESS)") + print(" B/Circle button: Exit program") + + def get_teleop_commands( + self, + ) -> tuple[bool, np.ndarray, bool, bool, bool]: + """ + Get the current action from the gamepad if any input is active. + + Returns: + Tuple containing: + - is_active: Whether gamepad input is active (from teleop_device.gamepad.should_intervene()) + - action: The action derived from gamepad input (from teleop_device.get_action()) + - terminate_episode: Whether episode termination was requested + - success: Whether episode success was signaled + - rerecord_episode: Whether episode rerecording was requested + """ + if not hasattr(self.teleop_device, "gamepad") or self.teleop_device.gamepad is None: + raise AttributeError( + "teleop_device does not have a 'gamepad' attribute or it is None. Expected for GamepadControlWrapper." + ) + + # Get status flags from the underlying gamepad controller within the teleop_device + self.teleop_device.gamepad.update() # Ensure gamepad state is fresh + intervention_is_active = self.teleop_device.gamepad.should_intervene() + episode_end_status = self.teleop_device.gamepad.get_episode_end_status() + + terminate_episode = episode_end_status is not None + success = episode_end_status == "success" + rerecord_episode = episode_end_status == "rerecord_episode" + + # Get the action dictionary from the teleop_device + action_dict = self.teleop_device.get_action() + + # Convert action_dict to numpy array based on expected structure + # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) + action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] + if self.use_gripper: + # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) + # This needs to be consistent with what EEActionWrapper expects if it's used downstream + # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) + # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. + gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present + action_list.append(float(gripper_val)) + + gamepad_action_np = np.array(action_list, dtype=np.float32) + + return ( + intervention_is_active, + gamepad_action_np, + terminate_episode, + success, + rerecord_episode, + ) + + def step(self, action): + """ + Step the environment, using gamepad input to override actions when active. + + Args: + action: Original action from agent. + + Returns: + Tuple of (observation, reward, terminated, truncated, info). + """ + # Get gamepad state and action + ( + is_intervention, + gamepad_action, + terminate_episode, + success, + rerecord_episode, + ) = self.get_teleop_commands() + + # Update episode ending state if requested + if terminate_episode: + logging.info(f"Episode manually ended: {'SUCCESS' if success else 'FAILURE'}") + + # Only override the action if gamepad is active + action = gamepad_action if is_intervention else action + + # Step the environment + obs, reward, terminated, truncated, info = self.env.step(action) + + # Add episode ending if requested via gamepad + terminated = terminated or truncated or terminate_episode + + if success: + reward = 1.0 + logging.info("Episode ended successfully with reward 1.0") + + if isinstance(action, np.ndarray): + action = torch.from_numpy(action) + + info["is_intervention"] = is_intervention + # The original `BaseLeaderControlWrapper` puts `action_intervention` in info. + # For Gamepad, if intervention, `gamepad_action` is the intervention. + # If not intervention, policy's action is `action`. + # For consistency, let's store the *human's* action if intervention occurred. + info["action_intervention"] = action + + info["rerecord_episode"] = rerecord_episode + + # If episode ended, reset the state + if terminated or truncated: + # Add success/failure information to info dict + info["next.success"] = success + + # Auto reset if configured + if self.auto_reset: + obs, reset_info = self.reset() + info.update(reset_info) + + return obs, reward, terminated, truncated, info + + def close(self): + """ + Clean up resources when environment closes. + + Returns: + Result of closing the wrapped environment. + """ + if hasattr(self.teleop_device, "disconnect"): + self.teleop_device.disconnect() + + # Call the parent close method + return self.env.close() + + +class KeyboardControlWrapper(GamepadControlWrapper): + """ + Wrapper that allows controlling a gym environment with a keyboard. + + This wrapper intercepts the step method and allows human input via keyboard + to override the agent's actions when desired. + + Inherits from GamepadControlWrapper to avoid code duplication. + """ + + def __init__( + self, + env, + teleop_device, # Accepts an instantiated teleoperator + use_gripper=False, # This should align with teleop_device's config + auto_reset=False, + ): + """ + Initialize the gamepad controller wrapper. + + Args: + env: The environment to wrap. + teleop_device: The instantiated teleoperation device (e.g., GamepadTeleop). + use_gripper: Whether to include gripper control (should match teleop_device.config.use_gripper). + auto_reset: Whether to auto reset the environment when episode ends. + """ + super().__init__(env, teleop_device, use_gripper, auto_reset) + + self.is_intervention_active = False + + logging.info("Keyboard control wrapper initialized with provided teleop_device.") + print("Keyboard controls:") + print(" Arrow keys: Move in X-Y plane") + print(" Shift and Shift_R: Move in Z axis") + print(" Right Ctrl and Left Ctrl: Open and close gripper") + print(" f: End episode with FAILURE") + print(" s: End episode with SUCCESS") + print(" r: End episode with RERECORD") + print(" i: Start/Stop Intervention") + + def get_teleop_commands( + self, + ) -> tuple[bool, np.ndarray, bool, bool, bool]: + action_dict = self.teleop_device.get_action() + episode_end_status = None + + # Unroll the misc_keys_queue to check for events related to intervention, episode success, etc. + while not self.teleop_device.misc_keys_queue.empty(): + key = self.teleop_device.misc_keys_queue.get() + if key == "i": + self.is_intervention_active = not self.is_intervention_active + elif key == "f": + episode_end_status = "failure" + elif key == "s": + episode_end_status = "success" + elif key == "r": + episode_end_status = "rerecord_episode" + + terminate_episode = episode_end_status is not None + success = episode_end_status == "success" + rerecord_episode = episode_end_status == "rerecord_episode" + + # Convert action_dict to numpy array based on expected structure + # Order: delta_x, delta_y, delta_z, gripper (if use_gripper) + action_list = [action_dict["delta_x"], action_dict["delta_y"], action_dict["delta_z"]] + if self.use_gripper: + # GamepadTeleop returns gripper action as 0 (close), 1 (stay), 2 (open) + # This needs to be consistent with what EEActionWrapper expects if it's used downstream + # EEActionWrapper for gripper typically expects 0.0 (closed) to 2.0 (open) + # For now, we pass the direct value from GamepadTeleop, ensure downstream compatibility. + gripper_val = action_dict.get("gripper", 1.0) # Default to 1.0 (stay) if not present + action_list.append(float(gripper_val)) + + gamepad_action_np = np.array(action_list, dtype=np.float32) + + return ( + self.is_intervention_active, + gamepad_action_np, + terminate_episode, + success, + rerecord_episode, + ) + + +class GymHilDeviceWrapper(gym.Wrapper): + def __init__(self, env, device="cpu"): + super().__init__(env) + self.device = device + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + for k in obs: + obs[k] = obs[k].to(self.device) + if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) + info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) + return obs, reward, terminated, truncated, info + + def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None): + obs, info = self.env.reset(seed=seed, options=options) + for k in obs: + obs[k] = obs[k].to(self.device) + if "action_intervention" in info: + # NOTE: This is a hack to ensure the action intervention is a float32 tensor and supported on MPS device + info["action_intervention"] = info["action_intervention"].astype(np.float32) + info["action_intervention"] = torch.from_numpy(info["action_intervention"]).to(self.device) + return obs, info + + +class GymHilObservationProcessorWrapper(gym.ObservationWrapper): + def __init__(self, env: gym.Env): + super().__init__(env) + prev_space = self.observation_space + new_space = {} + + for key in prev_space: + if "pixels" in key: + for k in prev_space["pixels"]: + new_space[f"observation.images.{k}"] = gym.spaces.Box( + 0.0, 255.0, shape=(3, 128, 128), dtype=np.uint8 + ) + + if key == "agent_pos": + new_space["observation.state"] = prev_space["agent_pos"] + + self.observation_space = gym.spaces.Dict(new_space) + + def observation(self, observation: dict[str, Any]) -> dict[str, Any]: + return preprocess_observation(observation) + + +########################################################### +# Factory functions +########################################################### + + +def make_robot_env(cfg: EnvConfig) -> gym.Env: + """ + Factory function to create a robot environment. + + This function builds a robot environment with all necessary wrappers + based on the provided configuration. + + Args: + cfg: Configuration object containing environment parameters. + + Returns: + A gym environment with all necessary wrappers applied. + """ + if cfg.type == "hil": + import gym_hil # noqa: F401 + + # TODO (azouitine) + env = gym.make( + f"gym_hil/{cfg.task}", + image_obs=True, + render_mode="human", + use_gripper=cfg.wrapper.use_gripper, + gripper_penalty=cfg.wrapper.gripper_penalty, + ) + env = GymHilObservationProcessorWrapper(env=env) + env = GymHilDeviceWrapper(env=env, device=cfg.device) + env = BatchCompatibleWrapper(env=env) + env = TorchActionWrapper(env=env, device=cfg.device) + return env + + if not hasattr(cfg, "robot") or not hasattr(cfg, "teleop"): + raise ValueError( + "Configuration for 'gym_manipulator' must be HILSerlRobotEnvConfig with robot and teleop." + ) + + if cfg.robot is None: + raise ValueError("RobotConfig (cfg.robot) must be provided for gym_manipulator environment.") + robot = make_robot_from_config(cfg.robot) + + teleop_device = make_teleoperator_from_config(cfg.teleop) + teleop_device.connect() + + # Create base environment + env = RobotEnv( + robot=robot, + use_gripper=cfg.wrapper.use_gripper, + display_cameras=cfg.wrapper.display_cameras if cfg.wrapper else False, + ) + + # Add observation and image processing + if cfg.wrapper: + if cfg.wrapper.add_joint_velocity_to_observation: + env = AddJointVelocityToObservation(env=env, fps=cfg.fps) + if cfg.wrapper.add_current_to_observation: + env = AddCurrentToObservation(env=env) + if cfg.wrapper.add_ee_pose_to_observation: + env = EEObservationWrapper(env=env, ee_pose_limits=robot.end_effector_bounds) + + env = ConvertToLeRobotObservation(env=env, device=cfg.device) + + if cfg.wrapper and cfg.wrapper.crop_params_dict is not None: + env = ImageCropResizeWrapper( + env=env, + crop_params_dict=cfg.wrapper.crop_params_dict, + resize_size=cfg.wrapper.resize_size, + ) + + # Add reward computation and control wrappers + reward_classifier = init_reward_classifier(cfg) + if reward_classifier is not None: + env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device) + + env = TimeLimitWrapper(env=env, control_time_s=cfg.wrapper.control_time_s, fps=cfg.fps) + if cfg.wrapper.use_gripper and cfg.wrapper.gripper_penalty is not None: + env = GripperPenaltyWrapper( + env=env, + penalty=cfg.wrapper.gripper_penalty, + ) + + # Control mode specific wrappers + control_mode = cfg.wrapper.control_mode + if control_mode == "gamepad": + assert isinstance(teleop_device, GamepadTeleop), ( + "teleop_device must be an instance of GamepadTeleop for gamepad control mode" + ) + env = GamepadControlWrapper( + env=env, + teleop_device=teleop_device, + use_gripper=cfg.wrapper.use_gripper, + ) + elif control_mode == "keyboard_ee": + assert isinstance(teleop_device, KeyboardEndEffectorTeleop), ( + "teleop_device must be an instance of KeyboardEndEffectorTeleop for keyboard control mode" + ) + env = KeyboardControlWrapper( + env=env, + teleop_device=teleop_device, + use_gripper=cfg.wrapper.use_gripper, + ) + elif control_mode == "leader": + env = GearedLeaderControlWrapper( + env=env, + teleop_device=teleop_device, + end_effector_step_sizes=cfg.robot.end_effector_step_sizes, + use_gripper=cfg.wrapper.use_gripper, + ) + elif control_mode == "leader_automatic": + env = GearedLeaderAutomaticControlWrapper( + env=env, + teleop_device=teleop_device, + end_effector_step_sizes=cfg.robot.end_effector_step_sizes, + use_gripper=cfg.wrapper.use_gripper, + ) + else: + raise ValueError(f"Invalid control mode: {control_mode}") + + env = ResetWrapper( + env=env, + reset_pose=cfg.wrapper.fixed_reset_joint_positions, + reset_time_s=cfg.wrapper.reset_time_s, + ) + + env = BatchCompatibleWrapper(env=env) + env = TorchActionWrapper(env=env, device=cfg.device) + + return env + + +def init_reward_classifier(cfg): + """ + Load a reward classifier policy from a pretrained path if configured. + + Args: + cfg: The environment configuration containing classifier paths. + + Returns: + The loaded classifier model or None if not configured. + """ + if cfg.reward_classifier_pretrained_path is None: + return None + + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + # Get device from config or default to CUDA + device = getattr(cfg, "device", "cpu") + + # Load the classifier directly using from_pretrained + classifier = Classifier.from_pretrained( + pretrained_name_or_path=cfg.reward_classifier_pretrained_path, + ) + + # Ensure model is on the correct device + classifier.to(device) + classifier.eval() # Set to evaluation mode + + return classifier + + +########################################################### +# Record and replay functions +########################################################### + + +def record_dataset(env, policy, cfg): + """ + Record a dataset of robot interactions using either a policy or teleop. + + This function runs episodes in the environment and records the observations, + actions, and results for dataset creation. + + Args: + env: The environment to record from. + policy: Optional policy to generate actions (if None, uses teleop). + cfg: Configuration object containing recording parameters like: + - repo_id: Repository ID for dataset storage + - dataset_root: Local root directory for dataset + - num_episodes: Number of episodes to record + - fps: Frames per second for recording + - push_to_hub: Whether to push dataset to Hugging Face Hub + - task: Name/description of the task being recorded + - number_of_steps_after_success: Number of additional steps to continue recording after + a success (reward=1) is detected. This helps collect + more positive examples for reward classifier training. + """ + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + # Setup initial action (zero action if using teleop) + action = env.action_space.sample() * 0.0 + + action_names = ["delta_x_ee", "delta_y_ee", "delta_z_ee"] + if cfg.wrapper.use_gripper: + action_names.append("gripper_delta") + + # Configure dataset features based on environment spaces + features = { + "observation.state": { + "dtype": "float32", + "shape": env.observation_space["observation.state"].shape, + "names": None, + }, + "action": { + "dtype": "float32", + "shape": (len(action_names),), + "names": action_names, + }, + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + "complementary_info.discrete_penalty": { + "dtype": "float32", + "shape": (1,), + "names": ["discrete_penalty"], + }, + } + + # Add image features + for key in env.observation_space: + if "image" in key: + features[key] = { + "dtype": "video", + "shape": env.observation_space[key].shape, + "names": ["channels", "height", "width"], + } + + # Create dataset + dataset = LeRobotDataset.create( + cfg.repo_id, + cfg.fps, + root=cfg.dataset_root, + use_videos=True, + image_writer_threads=4, + image_writer_processes=0, + features=features, + ) + + # Record episodes + episode_index = 0 + recorded_action = None + while episode_index < cfg.num_episodes: + obs, _ = env.reset() + start_episode_t = time.perf_counter() + log_say(f"Recording episode {episode_index}", play_sounds=True) + + # Track success state collection + success_detected = False + success_steps_collected = 0 + + # Run episode steps + while time.perf_counter() - start_episode_t < cfg.wrapper.control_time_s: + start_loop_t = time.perf_counter() + + # Get action from policy if available + if cfg.pretrained_policy_name_or_path is not None: + action = policy.select_action(obs) + + # Step environment + obs, reward, terminated, truncated, info = env.step(action) + + # Check if episode needs to be rerecorded + if info.get("rerecord_episode", False): + break + + # For teleop, get action from intervention + recorded_action = { + "action": info["action_intervention"].cpu().squeeze(0).float() if policy is None else action + } + + # Process observation for dataset + obs_processed = {k: v.cpu().squeeze(0).float() for k, v in obs.items()} + + # Check if we've just detected success + if reward == 1.0 and not success_detected: + success_detected = True + logging.info("Success detected! Collecting additional success states.") + + # Add frame to dataset - continue marking as success even during extra collection steps + frame = {**obs_processed, **recorded_action} + + # If we're in the success collection phase, keep marking rewards as 1.0 + if success_detected: + frame["next.reward"] = np.array([1.0], dtype=np.float32) + else: + frame["next.reward"] = np.array([reward], dtype=np.float32) + + # Only mark as done if we're truly done (reached end or collected enough success states) + really_done = terminated or truncated + if success_detected: + success_steps_collected += 1 + really_done = success_steps_collected >= cfg.number_of_steps_after_success + + frame["next.done"] = np.array([really_done], dtype=bool) + frame["complementary_info.discrete_penalty"] = torch.tensor( + [info.get("discrete_penalty", 0.0)], dtype=torch.float32 + ) + dataset.add_frame(frame, task=cfg.task) + + # Maintain consistent timing + if cfg.fps: + dt_s = time.perf_counter() - start_loop_t + busy_wait(1 / cfg.fps - dt_s) + + # Check if we should end the episode + if (terminated or truncated) and not success_detected: + # Regular termination without success + break + elif success_detected and success_steps_collected >= cfg.number_of_steps_after_success: + # We've collected enough success states + logging.info(f"Collected {success_steps_collected} additional success states") + break + + # Handle episode recording + if info.get("rerecord_episode", False): + dataset.clear_episode_buffer() + logging.info(f"Re-recording episode {episode_index}") + continue + + dataset.save_episode() + episode_index += 1 + + # Finalize dataset + # dataset.consolidate(run_compute_stats=True) + if cfg.push_to_hub: + dataset.push_to_hub() + + +def replay_episode(env, cfg): + """ + Replay a recorded episode in the environment. + + This function loads actions from a previously recorded episode + and executes them in the environment. + + Args: + env: The environment to replay in. + cfg: Configuration object containing replay parameters: + - repo_id: Repository ID for dataset + - dataset_root: Local root directory for dataset + - episode: Episode ID to replay + """ + from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + dataset = LeRobotDataset(cfg.repo_id, root=cfg.dataset_root, episodes=[cfg.episode]) + env.reset() + + actions = dataset.hf_dataset.select_columns("action") + + for idx in range(dataset.num_frames): + start_episode_t = time.perf_counter() + + action = actions[idx]["action"] + env.step(action) + + dt_s = time.perf_counter() - start_episode_t + busy_wait(1 / 10 - dt_s) + + +@parser.wrap() +def main(cfg: EnvConfig): + """Main entry point for the robot environment script. + + This function runs the robot environment in one of several modes + based on the provided configuration. + + Args: + cfg: Configuration object defining the run parameters, + including mode (record, replay, random) and other settings. + """ + env = make_robot_env(cfg) + + if cfg.mode == "record": + policy = None + if cfg.pretrained_policy_name_or_path is not None: + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + policy = SACPolicy.from_pretrained(cfg.pretrained_policy_name_or_path) + policy.to(cfg.device) + policy.eval() + + record_dataset( + env, + policy=policy, + cfg=cfg, + ) + exit() + + if cfg.mode == "replay": + replay_episode( + env, + cfg=cfg, + ) + exit() + + env.reset() + + # Initialize the smoothed action as a random sample. + smoothed_action = env.action_space.sample() * 0.0 + + # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. + # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. + alpha = 1.0 + + num_episode = 0 + successes = [] + while num_episode < 10: + start_loop_s = time.perf_counter() + # Sample a new random action from the robot's action space. + new_random_action = env.action_space.sample() + # Update the smoothed action using an exponential moving average. + smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action + + # Execute the step: wrap the NumPy action in a torch tensor. + obs, reward, terminated, truncated, info = env.step(smoothed_action) + if terminated or truncated: + successes.append(reward) + env.reset() + num_episode += 1 + + dt_s = time.perf_counter() - start_loop_s + busy_wait(1 / cfg.fps - dt_s) + + logging.info(f"Success after 20 steps {successes}") + logging.info(f"success rate {sum(successes) / len(successes)}") + + +if __name__ == "__main__": + main() diff --git a/lerobot/scripts/rl/learner.py b/lerobot/scripts/rl/learner.py new file mode 100644 index 000000000..663dbe918 --- /dev/null +++ b/lerobot/scripts/rl/learner.py @@ -0,0 +1,1206 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Learner server runner for distributed HILSerl robot policy training. + +This script implements the learner component of the distributed HILSerl architecture. +It initializes the policy network, maintains replay buffers, and updates +the policy based on transitions received from the actor server. + +Examples of usage: + +- Start a learner server for training: +```bash +python lerobot/scripts/rl/learner.py --config_path lerobot/configs/train_config_hilserl_so100.json +``` + +**NOTE**: Start the learner server before launching the actor server. The learner opens a gRPC server +to communicate with actors. + +**NOTE**: Training progress can be monitored through Weights & Biases if wandb.enable is set to true +in your configuration. + +**WORKFLOW**: +1. Create training configuration with proper policy, dataset, and environment settings +2. Start this learner server with the configuration +3. Start an actor server with the same configuration +4. Monitor training progress through wandb dashboard + +For more details on the complete HILSerl training workflow, see: +https://github.com/michel-aractingi/lerobot-hilserl-guide +""" + +import logging +import os +import shutil +import time +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from pprint import pformat + +import grpc +import torch +from termcolor import colored +from torch import nn +from torch.multiprocessing import Queue +from torch.optim.optimizer import Optimizer + +from lerobot.common.cameras import opencv # noqa: F401 +from lerobot.common.constants import ( + CHECKPOINTS_DIR, + LAST_CHECKPOINT_LINK, + PRETRAINED_MODEL_DIR, + TRAINING_STATE_DIR, +) +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robots import so100_follower # noqa: F401 +from lerobot.common.teleoperators import gamepad, so101_leader # noqa: F401 +from lerobot.common.transport import services_pb2_grpc +from lerobot.common.transport.utils import ( + bytes_to_python_object, + bytes_to_transitions, + state_to_bytes, +) +from lerobot.common.utils.buffer import ReplayBuffer, concatenate_batch_transitions +from lerobot.common.utils.process import ProcessSignalHandler +from lerobot.common.utils.random_utils import set_seed +from lerobot.common.utils.train_utils import ( + get_step_checkpoint_dir, + save_checkpoint, + update_last_checkpoint, +) +from lerobot.common.utils.train_utils import ( + load_training_state as utils_load_training_state, +) +from lerobot.common.utils.transition import move_state_dict_to_device, move_transition_to_device +from lerobot.common.utils.utils import ( + format_big_number, + get_safe_torch_device, + init_logging, +) +from lerobot.common.utils.wandb_utils import WandBLogger +from lerobot.configs import parser +from lerobot.configs.train import TrainRLServerPipelineConfig +from lerobot.scripts.rl import learner_service + +LOG_PREFIX = "[LEARNER]" + + +################################################# +# MAIN ENTRY POINTS AND CORE ALGORITHM FUNCTIONS # +################################################# + + +@parser.wrap() +def train_cli(cfg: TrainRLServerPipelineConfig): + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + + # Use the job_name from the config + train( + cfg, + job_name=cfg.job_name, + ) + + logging.info("[LEARNER] train_cli finished") + + +def train(cfg: TrainRLServerPipelineConfig, job_name: str | None = None): + """ + Main training function that initializes and runs the training process. + + Args: + cfg (TrainRLServerPipelineConfig): The training configuration + job_name (str | None, optional): Job name for logging. Defaults to None. + """ + + cfg.validate() + + if job_name is None: + job_name = cfg.job_name + + if job_name is None: + raise ValueError("Job name must be specified either in config or as a parameter") + + display_pid = False + if not use_threads(cfg): + display_pid = True + + # Create logs directory to ensure it exists + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_{job_name}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=display_pid) + logging.info(f"Learner logging initialized, writing to {log_file}") + logging.info(pformat(cfg.to_dict())) + + # Setup WandB logging if enabled + if cfg.wandb.enable and cfg.wandb.project: + from lerobot.common.utils.wandb_utils import WandBLogger + + wandb_logger = WandBLogger(cfg) + else: + wandb_logger = None + logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"])) + + # Handle resume logic + cfg = handle_resume_logic(cfg) + + set_seed(seed=cfg.seed) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + is_threaded = use_threads(cfg) + shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event + + start_learner_threads( + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + ) + + +def start_learner_threads( + cfg: TrainRLServerPipelineConfig, + wandb_logger: WandBLogger | None, + shutdown_event: any, # Event, +) -> None: + """ + Start the learner threads for training. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + wandb_logger (WandBLogger | None): Logger for metrics + shutdown_event: Event to signal shutdown + """ + # Create multiprocessing queues + transition_queue = Queue() + interaction_message_queue = Queue() + parameters_queue = Queue() + + concurrency_entity = None + + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from torch.multiprocessing import Process + + concurrency_entity = Process + + communication_process = concurrency_entity( + target=start_learner, + args=( + parameters_queue, + transition_queue, + interaction_message_queue, + shutdown_event, + cfg, + ), + daemon=True, + ) + communication_process.start() + + add_actor_information_and_train( + cfg=cfg, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + parameters_queue=parameters_queue, + ) + logging.info("[LEARNER] Training process stopped") + + logging.info("[LEARNER] Closing queues") + transition_queue.close() + interaction_message_queue.close() + parameters_queue.close() + + communication_process.join() + logging.info("[LEARNER] Communication process joined") + + logging.info("[LEARNER] join queues") + transition_queue.cancel_join_thread() + interaction_message_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[LEARNER] queues closed") + + +################################################# +# Core algorithm functions # +################################################# + + +def add_actor_information_and_train( + cfg: TrainRLServerPipelineConfig, + wandb_logger: WandBLogger | None, + shutdown_event: any, # Event, + transition_queue: Queue, + interaction_message_queue: Queue, + parameters_queue: Queue, +): + """ + Handles data transfer from the actor to the learner, manages training updates, + and logs training progress in an online reinforcement learning setup. + + This function continuously: + - Transfers transitions from the actor to the replay buffer. + - Logs received interaction messages. + - Ensures training begins only when the replay buffer has a sufficient number of transitions. + - Samples batches from the replay buffer and performs multiple critic updates. + - Periodically updates the actor, critic, and temperature optimizers. + - Logs training statistics, including loss values and optimization frequency. + + NOTE: This function doesn't have a single responsibility, it should be split into multiple functions + in the future. The reason why we did that is the GIL in Python. It's super slow the performance + are divided by 200. So we need to have a single thread that does all the work. + + Args: + cfg (TrainRLServerPipelineConfig): Configuration object containing hyperparameters. + wandb_logger (WandBLogger | None): Logger for tracking training progress. + shutdown_event (Event): Event to signal shutdown. + transition_queue (Queue): Queue for receiving transitions from the actor. + interaction_message_queue (Queue): Queue for receiving interaction messages from the actor. + parameters_queue (Queue): Queue for sending policy parameters to the actor. + """ + # Extract all configuration variables at the beginning, it improve the speed performance + # of 7% + device = get_safe_torch_device(try_device=cfg.policy.device, log=True) + storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device) + clip_grad_norm_value = cfg.policy.grad_clip_norm + online_step_before_learning = cfg.policy.online_step_before_learning + utd_ratio = cfg.policy.utd_ratio + fps = cfg.env.fps + log_freq = cfg.log_freq + save_freq = cfg.save_freq + policy_update_freq = cfg.policy.policy_update_freq + policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency + saving_checkpoint = cfg.save_checkpoint + online_steps = cfg.policy.online_steps + async_prefetch = cfg.policy.async_prefetch + + # Initialize logging for multiprocessing + if not use_threads(cfg): + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_train_process_{os.getpid()}.log") + init_logging(log_file=log_file, display_pid=True) + logging.info("Initialized logging for actor information and training process") + + logging.info("Initializing policy") + + policy: SACPolicy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + + assert isinstance(policy, nn.Module) + + policy.train() + + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + + last_time_policy_pushed = time.time() + + optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy) + + # If we are resuming, we need to load the training state + resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers) + + log_training_info(cfg=cfg, policy=policy) + + replay_buffer = initialize_replay_buffer(cfg, device, storage_device) + batch_size = cfg.batch_size + offline_replay_buffer = None + + if cfg.dataset is not None: + offline_replay_buffer = initialize_offline_replay_buffer( + cfg=cfg, + device=device, + storage_device=storage_device, + ) + batch_size: int = batch_size // 2 # We will sample from both replay buffer + + logging.info("Starting learner thread") + interaction_message = None + optimization_step = resume_optimization_step if resume_optimization_step is not None else 0 + interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0 + + dataset_repo_id = None + if cfg.dataset is not None: + dataset_repo_id = cfg.dataset.repo_id + + # Initialize iterators + online_iterator = None + offline_iterator = None + + # NOTE: THIS IS THE MAIN LOOP OF THE LEARNER + while True: + # Exit the training loop if shutdown is requested + if shutdown_event is not None and shutdown_event.is_set(): + logging.info("[LEARNER] Shutdown signal received. Exiting...") + break + + # Process all available transitions to the replay buffer, send by the actor server + process_transitions( + transition_queue=transition_queue, + replay_buffer=replay_buffer, + offline_replay_buffer=offline_replay_buffer, + device=device, + dataset_repo_id=dataset_repo_id, + shutdown_event=shutdown_event, + ) + + # Process all available interaction messages sent by the actor server + interaction_message = process_interaction_messages( + interaction_message_queue=interaction_message_queue, + interaction_step_shift=interaction_step_shift, + wandb_logger=wandb_logger, + shutdown_event=shutdown_event, + ) + + # Wait until the replay buffer has enough samples to start training + if len(replay_buffer) < online_step_before_learning: + continue + + if online_iterator is None: + online_iterator = replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + if offline_replay_buffer is not None and offline_iterator is None: + offline_iterator = offline_replay_buffer.get_iterator( + batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2 + ) + + time_for_one_optimization_step = time.time() + for _ in range(utd_ratio - 1): + # Sample from the iterators + batch = next(online_iterator) + + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) + + actions = batch["action"] + rewards = batch["reward"] + observations = batch["state"] + next_observations = batch["next_state"] + done = batch["done"] + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) + + observation_features, next_observation_features = get_observation_features( + policy=policy, observations=observations, next_observations=next_observations + ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + "complementary_info": batch["complementary_info"], + } + + # Use the forward method for critic loss + critic_output = policy.forward(forward_batch, model="critic") + + # Main critic optimization + loss_critic = critic_output["loss_critic"] + optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["critic"].step() + + # Discrete critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") + loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] + optimizers["discrete_critic"].zero_grad() + loss_discrete_critic.backward() + discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value + ) + optimizers["discrete_critic"].step() + + # Update target networks (main and discrete) + policy.update_target_networks() + + # Sample for the last update in the UTD ratio + batch = next(online_iterator) + + if dataset_repo_id is not None: + batch_offline = next(offline_iterator) + batch = concatenate_batch_transitions( + left_batch_transitions=batch, right_batch_transition=batch_offline + ) + + actions = batch["action"] + rewards = batch["reward"] + observations = batch["state"] + next_observations = batch["next_state"] + done = batch["done"] + + check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations) + + observation_features, next_observation_features = get_observation_features( + policy=policy, observations=observations, next_observations=next_observations + ) + + # Create a batch dictionary with all required elements for the forward method + forward_batch = { + "action": actions, + "reward": rewards, + "state": observations, + "next_state": next_observations, + "done": done, + "observation_feature": observation_features, + "next_observation_feature": next_observation_features, + } + + critic_output = policy.forward(forward_batch, model="critic") + + loss_critic = critic_output["loss_critic"] + optimizers["critic"].zero_grad() + loss_critic.backward() + critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["critic"].step() + + # Initialize training info dictionary + training_infos = { + "loss_critic": loss_critic.item(), + "critic_grad_norm": critic_grad_norm, + } + + # Discrete critic optimization (if available) + if policy.config.num_discrete_actions is not None: + discrete_critic_output = policy.forward(forward_batch, model="discrete_critic") + loss_discrete_critic = discrete_critic_output["loss_discrete_critic"] + optimizers["discrete_critic"].zero_grad() + loss_discrete_critic.backward() + discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["discrete_critic"].step() + + # Add discrete critic info to training info + training_infos["loss_discrete_critic"] = loss_discrete_critic.item() + training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm + + # Actor and temperature optimization (at specified frequency) + if optimization_step % policy_update_freq == 0: + for _ in range(policy_update_freq): + # Actor optimization + actor_output = policy.forward(forward_batch, model="actor") + loss_actor = actor_output["loss_actor"] + optimizers["actor"].zero_grad() + loss_actor.backward() + actor_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value + ).item() + optimizers["actor"].step() + + # Add actor info to training info + training_infos["loss_actor"] = loss_actor.item() + training_infos["actor_grad_norm"] = actor_grad_norm + + # Temperature optimization + temperature_output = policy.forward(forward_batch, model="temperature") + loss_temperature = temperature_output["loss_temperature"] + optimizers["temperature"].zero_grad() + loss_temperature.backward() + temp_grad_norm = torch.nn.utils.clip_grad_norm_( + parameters=[policy.log_alpha], max_norm=clip_grad_norm_value + ).item() + optimizers["temperature"].step() + + # Add temperature info to training info + training_infos["loss_temperature"] = loss_temperature.item() + training_infos["temperature_grad_norm"] = temp_grad_norm + training_infos["temperature"] = policy.temperature + + # Update temperature + policy.update_temperature() + + # Push policy to actors if needed + if time.time() - last_time_policy_pushed > policy_parameters_push_frequency: + push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy) + last_time_policy_pushed = time.time() + + # Update target networks (main and discrete) + policy.update_target_networks() + + # Log training metrics at specified intervals + if optimization_step % log_freq == 0: + training_infos["replay_buffer_size"] = len(replay_buffer) + if offline_replay_buffer is not None: + training_infos["offline_replay_buffer_size"] = len(offline_replay_buffer) + training_infos["Optimization step"] = optimization_step + + # Log training metrics + if wandb_logger: + wandb_logger.log_dict(d=training_infos, mode="train", custom_step_key="Optimization step") + + # Calculate and log optimization frequency + time_for_one_optimization_step = time.time() - time_for_one_optimization_step + frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9) + + logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}") + + # Log optimization frequency + if wandb_logger: + wandb_logger.log_dict( + { + "Optimization frequency loop [Hz]": frequency_for_one_optimization_step, + "Optimization step": optimization_step, + }, + mode="train", + custom_step_key="Optimization step", + ) + + optimization_step += 1 + if optimization_step % log_freq == 0: + logging.info(f"[LEARNER] Number of optimization step: {optimization_step}") + + # Save checkpoint at specified intervals + if saving_checkpoint and (optimization_step % save_freq == 0 or optimization_step == online_steps): + save_training_checkpoint( + cfg=cfg, + optimization_step=optimization_step, + online_steps=online_steps, + interaction_message=interaction_message, + policy=policy, + optimizers=optimizers, + replay_buffer=replay_buffer, + offline_replay_buffer=offline_replay_buffer, + dataset_repo_id=dataset_repo_id, + fps=fps, + ) + + +def start_learner( + parameters_queue: Queue, + transition_queue: Queue, + interaction_message_queue: Queue, + shutdown_event: any, # Event, + cfg: TrainRLServerPipelineConfig, +): + """ + Start the learner server for training. + It will receive transitions and interaction messages from the actor server, + and send policy parameters to the actor server. + + Args: + parameters_queue: Queue for sending policy parameters to the actor + transition_queue: Queue for receiving transitions from the actor + interaction_message_queue: Queue for receiving interaction messages from the actor + shutdown_event: Event to signal shutdown + cfg: Training configuration + """ + if not use_threads(cfg): + # Create a process-specific log file + log_dir = os.path.join(cfg.output_dir, "logs") + os.makedirs(log_dir, exist_ok=True) + log_file = os.path.join(log_dir, f"learner_process_{os.getpid()}.log") + + # Initialize logging with explicit log file + init_logging(log_file=log_file, display_pid=True) + logging.info("Learner server process logging initialized") + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + # Return back for MP + # TODO: Check if its useful + _ = ProcessSignalHandler(False, display_pid=True) + + service = learner_service.LearnerService( + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + seconds_between_pushes=cfg.policy.actor_learner_config.policy_parameters_push_frequency, + transition_queue=transition_queue, + interaction_message_queue=interaction_message_queue, + queue_get_timeout=cfg.policy.actor_learner_config.queue_get_timeout, + ) + + server = grpc.server( + ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ], + ) + + services_pb2_grpc.add_LearnerServiceServicer_to_server( + service, + server, + ) + + host = cfg.policy.actor_learner_config.learner_host + port = cfg.policy.actor_learner_config.learner_port + + server.add_insecure_port(f"{host}:{port}") + server.start() + logging.info("[LEARNER] gRPC server started") + + shutdown_event.wait() + logging.info("[LEARNER] Stopping gRPC server...") + server.stop(learner_service.SHUTDOWN_TIMEOUT) + logging.info("[LEARNER] gRPC server stopped") + + +def save_training_checkpoint( + cfg: TrainRLServerPipelineConfig, + optimization_step: int, + online_steps: int, + interaction_message: dict | None, + policy: nn.Module, + optimizers: dict[str, Optimizer], + replay_buffer: ReplayBuffer, + offline_replay_buffer: ReplayBuffer | None = None, + dataset_repo_id: str | None = None, + fps: int = 30, +) -> None: + """ + Save training checkpoint and associated data. + + This function performs the following steps: + 1. Creates a checkpoint directory with the current optimization step + 2. Saves the policy model, configuration, and optimizer states + 3. Saves the current interaction step for resuming training + 4. Updates the "last" checkpoint symlink to point to this checkpoint + 5. Saves the replay buffer as a dataset for later use + 6. If an offline replay buffer exists, saves it as a separate dataset + + Args: + cfg: Training configuration + optimization_step: Current optimization step + online_steps: Total number of online steps + interaction_message: Dictionary containing interaction information + policy: Policy model to save + optimizers: Dictionary of optimizers + replay_buffer: Replay buffer to save as dataset + offline_replay_buffer: Optional offline replay buffer to save + dataset_repo_id: Repository ID for dataset + fps: Frames per second for dataset + """ + logging.info(f"Checkpoint policy after step {optimization_step}") + _num_digits = max(6, len(str(online_steps))) + interaction_step = interaction_message["Interaction step"] if interaction_message is not None else 0 + + # Create checkpoint directory + checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, online_steps, optimization_step) + + # Save checkpoint + save_checkpoint( + checkpoint_dir=checkpoint_dir, + step=optimization_step, + cfg=cfg, + policy=policy, + optimizer=optimizers, + scheduler=None, + ) + + # Save interaction step manually + training_state_dir = os.path.join(checkpoint_dir, TRAINING_STATE_DIR) + os.makedirs(training_state_dir, exist_ok=True) + training_state = {"step": optimization_step, "interaction_step": interaction_step} + torch.save(training_state, os.path.join(training_state_dir, "training_state.pt")) + + # Update the "last" symlink + update_last_checkpoint(checkpoint_dir) + + # TODO : temporary save replay buffer here, remove later when on the robot + # We want to control this with the keyboard inputs + dataset_dir = os.path.join(cfg.output_dir, "dataset") + if os.path.exists(dataset_dir) and os.path.isdir(dataset_dir): + shutil.rmtree(dataset_dir) + + # Save dataset + # NOTE: Handle the case where the dataset repo id is not specified in the config + # eg. RL training without demonstrations data + repo_id_buffer_save = cfg.env.task if dataset_repo_id is None else dataset_repo_id + replay_buffer.to_lerobot_dataset(repo_id=repo_id_buffer_save, fps=fps, root=dataset_dir) + + if offline_replay_buffer is not None: + dataset_offline_dir = os.path.join(cfg.output_dir, "dataset_offline") + if os.path.exists(dataset_offline_dir) and os.path.isdir(dataset_offline_dir): + shutil.rmtree(dataset_offline_dir) + + offline_replay_buffer.to_lerobot_dataset( + cfg.dataset.repo_id, + fps=fps, + root=dataset_offline_dir, + ) + + logging.info("Resume training") + + +def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module): + """ + Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy. + + This function sets up Adam optimizers for: + - The **actor network**, ensuring that only relevant parameters are optimized. + - The **critic ensemble**, which evaluates the value function. + - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods. + + It also initializes a learning rate scheduler, though currently, it is set to `None`. + + NOTE: + - If the encoder is shared, its parameters are excluded from the actor's optimization process. + - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor. + + Args: + cfg: Configuration object containing hyperparameters. + policy (nn.Module): The policy model containing the actor, critic, and temperature components. + + Returns: + Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]: + A tuple containing: + - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers. + - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling. + + """ + optimizer_actor = torch.optim.Adam( + params=[ + p + for n, p in policy.actor.named_parameters() + if not policy.config.shared_encoder or not n.startswith("encoder") + ], + lr=cfg.policy.actor_lr, + ) + optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr) + + if cfg.policy.num_discrete_actions is not None: + optimizer_discrete_critic = torch.optim.Adam( + params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr + ) + optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr) + lr_scheduler = None + optimizers = { + "actor": optimizer_actor, + "critic": optimizer_critic, + "temperature": optimizer_temperature, + } + if cfg.policy.num_discrete_actions is not None: + optimizers["discrete_critic"] = optimizer_discrete_critic + return optimizers, lr_scheduler + + +################################################# +# Training setup functions # +################################################# + + +def handle_resume_logic(cfg: TrainRLServerPipelineConfig) -> TrainRLServerPipelineConfig: + """ + Handle the resume logic for training. + + If resume is True: + - Verifies that a checkpoint exists + - Loads the checkpoint configuration + - Logs resumption details + - Returns the checkpoint configuration + + If resume is False: + - Checks if an output directory exists (to prevent accidental overwriting) + - Returns the original configuration + + Args: + cfg (TrainRLServerPipelineConfig): The training configuration + + Returns: + TrainRLServerPipelineConfig: The updated configuration + + Raises: + RuntimeError: If resume is True but no checkpoint found, or if resume is False but directory exists + """ + out_dir = cfg.output_dir + + # Case 1: Not resuming, but need to check if directory exists to prevent overwrites + if not cfg.resume: + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if os.path.exists(checkpoint_dir): + raise RuntimeError( + f"Output directory {checkpoint_dir} already exists. Use `resume=true` to resume training." + ) + return cfg + + # Case 2: Resuming training + checkpoint_dir = os.path.join(out_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + if not os.path.exists(checkpoint_dir): + raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True") + + # Log that we found a valid checkpoint and are resuming + logging.info( + colored( + "Valid checkpoint found: resume=True detected, resuming previous run", + color="yellow", + attrs=["bold"], + ) + ) + + # Load config using Draccus + checkpoint_cfg_path = os.path.join(checkpoint_dir, PRETRAINED_MODEL_DIR, "train_config.json") + checkpoint_cfg = TrainRLServerPipelineConfig.from_pretrained(checkpoint_cfg_path) + + # Ensure resume flag is set in returned config + checkpoint_cfg.resume = True + return checkpoint_cfg + + +def load_training_state( + cfg: TrainRLServerPipelineConfig, + optimizers: Optimizer | dict[str, Optimizer], +): + """ + Loads the training state (optimizers, step count, etc.) from a checkpoint. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + optimizers (Optimizer | dict): Optimizers to load state into + + Returns: + tuple: (optimization_step, interaction_step) or (None, None) if not resuming + """ + if not cfg.resume: + return None, None + + # Construct path to the last checkpoint directory + checkpoint_dir = os.path.join(cfg.output_dir, CHECKPOINTS_DIR, LAST_CHECKPOINT_LINK) + + logging.info(f"Loading training state from {checkpoint_dir}") + + try: + # Use the utility function from train_utils which loads the optimizer state + step, optimizers, _ = utils_load_training_state(Path(checkpoint_dir), optimizers, None) + + # Load interaction step separately from training_state.pt + training_state_path = os.path.join(checkpoint_dir, TRAINING_STATE_DIR, "training_state.pt") + interaction_step = 0 + if os.path.exists(training_state_path): + training_state = torch.load(training_state_path, weights_only=False) # nosec B614: Safe usage of torch.load + interaction_step = training_state.get("interaction_step", 0) + + logging.info(f"Resuming from step {step}, interaction step {interaction_step}") + return step, interaction_step + + except Exception as e: + logging.error(f"Failed to load training state: {e}") + return None, None + + +def log_training_info(cfg: TrainRLServerPipelineConfig, policy: nn.Module) -> None: + """ + Log information about the training process. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + policy (nn.Module): Policy model + """ + num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_total_params = sum(p.numel() for p in policy.parameters()) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.policy.online_steps=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + + +def initialize_replay_buffer( + cfg: TrainRLServerPipelineConfig, device: str, storage_device: str +) -> ReplayBuffer: + """ + Initialize a replay buffer, either empty or from a dataset if resuming. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + + Returns: + ReplayBuffer: Initialized replay buffer + """ + if not cfg.resume: + return ReplayBuffer( + capacity=cfg.policy.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_features.keys(), + storage_device=storage_device, + optimize_memory=True, + ) + + logging.info("Resume training load the online dataset") + dataset_path = os.path.join(cfg.output_dir, "dataset") + + # NOTE: In RL is possible to not have a dataset. + repo_id = None + if cfg.dataset is not None: + repo_id = cfg.dataset.repo_id + dataset = LeRobotDataset( + repo_id=repo_id, + root=dataset_path, + ) + return ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, + capacity=cfg.policy.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_features.keys(), + optimize_memory=True, + ) + + +def initialize_offline_replay_buffer( + cfg: TrainRLServerPipelineConfig, + device: str, + storage_device: str, +) -> ReplayBuffer: + """ + Initialize an offline replay buffer from a dataset. + + Args: + cfg (TrainRLServerPipelineConfig): Training configuration + device (str): Device to store tensors on + storage_device (str): Device for storage optimization + + Returns: + ReplayBuffer: Initialized offline replay buffer + """ + if not cfg.resume: + logging.info("make_dataset offline buffer") + offline_dataset = make_dataset(cfg) + else: + logging.info("load offline dataset") + dataset_offline_path = os.path.join(cfg.output_dir, "dataset_offline") + offline_dataset = LeRobotDataset( + repo_id=cfg.dataset.repo_id, + root=dataset_offline_path, + ) + + logging.info("Convert to a offline replay buffer") + offline_replay_buffer = ReplayBuffer.from_lerobot_dataset( + offline_dataset, + device=device, + state_keys=cfg.policy.input_features.keys(), + storage_device=storage_device, + optimize_memory=True, + capacity=cfg.policy.offline_buffer_capacity, + ) + return offline_replay_buffer + + +################################################# +# Utilities/Helpers functions # +################################################# + + +def get_observation_features( + policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """ + Get observation features from the policy encoder. It act as cache for the observation features. + when the encoder is frozen, the observation features are not updated. + We can save compute by caching the observation features. + + Args: + policy: The policy model + observations: The current observations + next_observations: The next observations + + Returns: + tuple: observation_features, next_observation_features + """ + + if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder: + return None, None + + with torch.no_grad(): + observation_features = policy.actor.encoder.get_cached_image_features(observations, normalize=True) + next_observation_features = policy.actor.encoder.get_cached_image_features( + next_observations, normalize=True + ) + + return observation_features, next_observation_features + + +def use_threads(cfg: TrainRLServerPipelineConfig) -> bool: + return cfg.policy.concurrency.learner == "threads" + + +def check_nan_in_transition( + observations: torch.Tensor, + actions: torch.Tensor, + next_state: torch.Tensor, + raise_error: bool = False, +) -> bool: + """ + Check for NaN values in transition data. + + Args: + observations: Dictionary of observation tensors + actions: Action tensor + next_state: Dictionary of next state tensors + raise_error: If True, raises ValueError when NaN is detected + + Returns: + bool: True if NaN values were detected, False otherwise + """ + nan_detected = False + + # Check observations + for key, tensor in observations.items(): + if torch.isnan(tensor).any(): + logging.error(f"observations[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in observations[{key}]") + + # Check next state + for key, tensor in next_state.items(): + if torch.isnan(tensor).any(): + logging.error(f"next_state[{key}] contains NaN values") + nan_detected = True + if raise_error: + raise ValueError(f"NaN detected in next_state[{key}]") + + # Check actions + if torch.isnan(actions).any(): + logging.error("actions contains NaN values") + nan_detected = True + if raise_error: + raise ValueError("NaN detected in actions") + + return nan_detected + + +def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): + logging.debug("[LEARNER] Pushing actor policy to the queue") + state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu") + state_bytes = state_to_bytes(state_dict) + parameters_queue.put(state_bytes) + + +def process_interaction_message( + message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None +): + """Process a single interaction message with consistent handling.""" + message = bytes_to_python_object(message) + # Shift interaction step for consistency with checkpointed state + message["Interaction step"] += interaction_step_shift + + # Log if logger available + if wandb_logger: + wandb_logger.log_dict(d=message, mode="train", custom_step_key="Interaction step") + + return message + + +def process_transitions( + transition_queue: Queue, + replay_buffer: ReplayBuffer, + offline_replay_buffer: ReplayBuffer, + device: str, + dataset_repo_id: str | None, + shutdown_event: any, +): + """Process all available transitions from the queue. + + Args: + transition_queue: Queue for receiving transitions from the actor + replay_buffer: Replay buffer to add transitions to + offline_replay_buffer: Offline replay buffer to add transitions to + device: Device to move transitions to + dataset_repo_id: Repository ID for dataset + shutdown_event: Event to signal shutdown + """ + while not transition_queue.empty() and not shutdown_event.is_set(): + transition_list = transition_queue.get() + transition_list = bytes_to_transitions(buffer=transition_list) + + for transition in transition_list: + transition = move_transition_to_device(transition=transition, device=device) + + # Skip transitions with NaN values + if check_nan_in_transition( + observations=transition["state"], + actions=transition["action"], + next_state=transition["next_state"], + ): + logging.warning("[LEARNER] NaN detected in transition, skipping") + continue + + replay_buffer.add(**transition) + + # Add to offline buffer if it's an intervention + if dataset_repo_id is not None and transition.get("complementary_info", {}).get( + "is_intervention" + ): + offline_replay_buffer.add(**transition) + + +def process_interaction_messages( + interaction_message_queue: Queue, + interaction_step_shift: int, + wandb_logger: WandBLogger | None, + shutdown_event: any, +) -> dict | None: + """Process all available interaction messages from the queue. + + Args: + interaction_message_queue: Queue for receiving interaction messages + interaction_step_shift: Amount to shift interaction step by + wandb_logger: Logger for tracking progress + shutdown_event: Event to signal shutdown + + Returns: + dict | None: The last interaction message processed, or None if none were processed + """ + last_message = None + while not interaction_message_queue.empty() and not shutdown_event.is_set(): + message = interaction_message_queue.get() + last_message = process_interaction_message( + message=message, + interaction_step_shift=interaction_step_shift, + wandb_logger=wandb_logger, + ) + + return last_message + + +if __name__ == "__main__": + train_cli() + logging.info("[LEARNER] main finished") diff --git a/lerobot/scripts/rl/learner_service.py b/lerobot/scripts/rl/learner_service.py new file mode 100644 index 000000000..f967d812c --- /dev/null +++ b/lerobot/scripts/rl/learner_service.py @@ -0,0 +1,118 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from multiprocessing import Event, Queue + +from lerobot.common.transport import services_pb2, services_pb2_grpc +from lerobot.common.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks +from lerobot.common.utils.queue import get_last_item_from_queue + +MAX_MESSAGE_SIZE = 4 * 1024 * 1024 # 4 MB +MAX_WORKERS = 3 # Stream parameters, send transitions and interactions +SHUTDOWN_TIMEOUT = 10 + + +class LearnerService(services_pb2_grpc.LearnerServiceServicer): + """ + Implementation of the LearnerService gRPC service + This service is used to send parameters to the Actor and receive transitions and interactions from the Actor + check transport.proto for the gRPC service definition + """ + + def __init__( + self, + shutdown_event: Event, # type: ignore + parameters_queue: Queue, + seconds_between_pushes: float, + transition_queue: Queue, + interaction_message_queue: Queue, + queue_get_timeout: float = 0.001, + ): + self.shutdown_event = shutdown_event + self.parameters_queue = parameters_queue + self.seconds_between_pushes = seconds_between_pushes + self.transition_queue = transition_queue + self.interaction_message_queue = interaction_message_queue + self.queue_get_timeout = queue_get_timeout + + def StreamParameters(self, request, context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to stream parameters from the Actor") + + last_push_time = 0 + + while not self.shutdown_event.is_set(): + time_since_last_push = time.time() - last_push_time + if time_since_last_push < self.seconds_between_pushes: + self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push) + # Continue, because we could receive a shutdown event, + # and it's checked in the while loop + continue + + logging.info("[LEARNER] Push parameters to the Actor") + buffer = get_last_item_from_queue( + self.parameters_queue, block=True, timeout=self.queue_get_timeout + ) + + if buffer is None: + continue + + yield from send_bytes_in_chunks( + buffer, + services_pb2.Parameters, + log_prefix="[LEARNER] Sending parameters", + silent=True, + ) + + last_push_time = time.time() + logging.info("[LEARNER] Parameters sent") + + logging.info("[LEARNER] Stream parameters finished") + return services_pb2.Empty() + + def SendTransitions(self, request_iterator, _context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to receive transitions from the Actor") + + receive_bytes_in_chunks( + request_iterator, + self.transition_queue, + self.shutdown_event, + log_prefix="[LEARNER] transitions", + ) + + logging.debug("[LEARNER] Finished receiving transitions") + return services_pb2.Empty() + + def SendInteractions(self, request_iterator, _context): # noqa: N802 + # TODO: authorize the request + logging.info("[LEARNER] Received request to receive interactions from the Actor") + + receive_bytes_in_chunks( + request_iterator, + self.interaction_message_queue, + self.shutdown_event, + log_prefix="[LEARNER] interactions", + ) + + logging.debug("[LEARNER] Finished receiving interactions") + return services_pb2.Empty() + + def Ready(self, request, context): # noqa: N802 + return services_pb2.Empty() diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py index 3b20691f3..683ae8493 100644 --- a/lerobot/scripts/train.py +++ b/lerobot/scripts/train.py @@ -283,6 +283,9 @@ def train(cfg: TrainPipelineConfig): eval_env.close() logging.info("End of training") + if cfg.policy.push_to_hub: + policy.push_model_to_hub(cfg) + if __name__ == "__main__": init_logging() diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py index f48559172..d4f0d80f8 100644 --- a/lerobot/scripts/visualize_dataset_html.py +++ b/lerobot/scripts/visualize_dataset_html.py @@ -174,7 +174,10 @@ def run_server( dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys ] videos_info = [ - {"url": url_for("static", filename=video_path), "filename": video_path.parent.name} + { + "url": url_for("static", filename=str(video_path).replace("\\", "/")), + "filename": video_path.parent.name, + } for video_path in video_paths ] tasks = dataset.meta.episodes[episode_id]["tasks"] @@ -381,7 +384,7 @@ def visualize_dataset_html( if isinstance(dataset, LeRobotDataset): ln_videos_dir = static_dir / "videos" if not ln_videos_dir.exists(): - ln_videos_dir.symlink_to((dataset.root / "videos").resolve()) + ln_videos_dir.symlink_to((dataset.root / "videos").resolve().as_posix()) if serve: run_server(dataset, episodes, host, port, static_dir, template_dir) diff --git a/lerobot/setup_motors.py b/lerobot/setup_motors.py new file mode 100644 index 000000000..7909dc68d --- /dev/null +++ b/lerobot/setup_motors.py @@ -0,0 +1,84 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper to set motor ids and baudrate. + +Example: + +```shell +python -m lerobot.setup_motors \ + --teleop.type=so100_leader \ + --teleop.port=/dev/tty.usbmodem575E0031751 +``` +""" + +from dataclasses import dataclass + +import draccus + +from .common.robots import ( # noqa: F401 + RobotConfig, + koch_follower, + lekiwi, + make_robot_from_config, + so100_follower, + so101_follower, +) +from .common.teleoperators import ( # noqa: F401 + TeleoperatorConfig, + koch_leader, + make_teleoperator_from_config, + so100_leader, + so101_leader, +) + +COMPATIBLE_DEVICES = [ + "koch_follower", + "koch_leader", + "so100_follower", + "so100_leader", + "so101_follower", + "so101_leader", + "lekiwi", +] + + +@dataclass +class SetupConfig: + teleop: TeleoperatorConfig | None = None + robot: RobotConfig | None = None + + def __post_init__(self): + if bool(self.teleop) == bool(self.robot): + raise ValueError("Choose either a teleop or a robot.") + + self.device = self.robot if self.robot else self.teleop + + +@draccus.wrap() +def setup_motors(cfg: SetupConfig): + if cfg.device.type not in COMPATIBLE_DEVICES: + raise NotImplementedError + + if isinstance(cfg.device, RobotConfig): + device = make_robot_from_config(cfg.device) + else: + device = make_teleoperator_from_config(cfg.device) + + device.setup_motors() + + +if __name__ == "__main__": + setup_motors() diff --git a/lerobot/teleoperate.py b/lerobot/teleoperate.py new file mode 100644 index 000000000..6080dfb40 --- /dev/null +++ b/lerobot/teleoperate.py @@ -0,0 +1,137 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple script to control a robot from teleoperation. + +Example: + +```shell +python -m lerobot.teleoperate \ + --robot.type=so101_follower \ + --robot.port=/dev/tty.usbmodem58760431541 \ + --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ + --robot.id=black \ + --teleop.type=so101_leader \ + --teleop.port=/dev/tty.usbmodem58760431551 \ + --teleop.id=blue \ + --display_data=true +``` +""" + +import logging +import time +from dataclasses import asdict, dataclass +from pprint import pformat + +import draccus +import numpy as np +import rerun as rr + +from lerobot.common.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 +from lerobot.common.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 +from lerobot.common.robots import ( # noqa: F401 + Robot, + RobotConfig, + koch_follower, + make_robot_from_config, + so100_follower, + so101_follower, +) +from lerobot.common.teleoperators import ( + Teleoperator, + TeleoperatorConfig, + make_teleoperator_from_config, +) +from lerobot.common.utils.robot_utils import busy_wait +from lerobot.common.utils.utils import init_logging, move_cursor_up +from lerobot.common.utils.visualization_utils import _init_rerun + +from .common.teleoperators import gamepad, koch_leader, so100_leader, so101_leader # noqa: F401 + + +@dataclass +class TeleoperateConfig: + teleop: TeleoperatorConfig + robot: RobotConfig + # Limit the maximum frames per second. + fps: int = 60 + teleop_time_s: float | None = None + # Display all cameras on screen + display_data: bool = False + + +def teleop_loop( + teleop: Teleoperator, robot: Robot, fps: int, display_data: bool = False, duration: float | None = None +): + display_len = max(len(key) for key in robot.action_features) + start = time.perf_counter() + while True: + loop_start = time.perf_counter() + action = teleop.get_action() + if display_data: + observation = robot.get_observation() + for obs, val in observation.items(): + if isinstance(val, float): + rr.log(f"observation_{obs}", rr.Scalar(val)) + elif isinstance(val, np.ndarray): + rr.log(f"observation_{obs}", rr.Image(val), static=True) + for act, val in action.items(): + if isinstance(val, float): + rr.log(f"action_{act}", rr.Scalar(val)) + + robot.send_action(action) + dt_s = time.perf_counter() - loop_start + busy_wait(1 / fps - dt_s) + + loop_s = time.perf_counter() - loop_start + + print("\n" + "-" * (display_len + 10)) + print(f"{'NAME':<{display_len}} | {'NORM':>7}") + for motor, value in action.items(): + print(f"{motor:<{display_len}} | {value:>7.2f}") + print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)") + + if duration is not None and time.perf_counter() - start >= duration: + return + + move_cursor_up(len(action) + 5) + + +@draccus.wrap() +def teleoperate(cfg: TeleoperateConfig): + init_logging() + logging.info(pformat(asdict(cfg))) + if cfg.display_data: + _init_rerun(session_name="teleoperation") + + teleop = make_teleoperator_from_config(cfg.teleop) + robot = make_robot_from_config(cfg.robot) + + teleop.connect() + robot.connect() + + try: + teleop_loop(teleop, robot, cfg.fps, display_data=cfg.display_data, duration=cfg.teleop_time_s) + except KeyboardInterrupt: + pass + finally: + if cfg.display_data: + rr.rerun_shutdown() + teleop.disconnect() + robot.disconnect() + + +if __name__ == "__main__": + teleoperate() diff --git a/lerobot/templates/lerobot_modelcard_template.md b/lerobot/templates/lerobot_modelcard_template.md new file mode 100644 index 000000000..ca5c182ac --- /dev/null +++ b/lerobot/templates/lerobot_modelcard_template.md @@ -0,0 +1,74 @@ +--- +# For reference on model card metadata, see the spec: https://github.com/huggingface/hub-docs/blob/main/modelcard.md?plain=1 +# Doc / guide: https://huggingface.co/docs/hub/model-cards +{{ card_data }} +--- + +# Model Card for {{ model_name | default("Model ID", true) }} + + + +{% if model_name == "smolvla" %} +[SmolVLA](https://huggingface.co/papers/2506.01844) is a compact, efficient vision-language-action model that achieves competitive performance at reduced computational costs and can be deployed on consumer-grade hardware. +{% elif model_name == "act" %} +[Action Chunking with Transformers (ACT)](https://huggingface.co/papers/2304.13705) is an imitation-learning method that predicts short action chunks instead of single steps. It learns from teleoperated data and often achieves high success rates. +{% elif model_name == "tdmpc" %} +[TD-MPC](https://huggingface.co/papers/2203.04955) combines model-free and model-based approaches to improve sample efficiency and performance in continuous control tasks by using a learned latent dynamics model and terminal value function. +{% elif model_name == "diffusion" %} +[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation. +{% elif model_name == "vqbet" %} +[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills. +{% elif model_name == "pi0" %} +[Pi0](https://huggingface.co/papers/2410.24164) is a generalist vision-language-action transformer that converts multimodal observations and text instructions into robot actions for zero-shot task transfer. +{% elif model_name == "pi0fast" %} +[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time. +{% elif model_name == "sac" %} +[Soft Actor-Critic (SAC)](https://huggingface.co/papers/1801.01290) is an entropy-regularised actor-critic algorithm offering stable, sample-efficient learning in continuous-control environments. +{% elif model_name == "reward_classifier" %} +A reward classifier is a lightweight neural network that scores observations or trajectories for task success, providing a learned reward signal or offline evaluation when explicit rewards are unavailable. +{% else %} +_Model type not recognized — please update this template._ +{% endif %} + +This policy has been trained and pushed to the Hub using [LeRobot](https://github.com/huggingface/lerobot). +See the full documentation at [LeRobot Docs](https://huggingface.co/docs/lerobot/index). + +--- + +## How to Get Started with the Model + +For a complete walkthrough, see the [training guide](https://huggingface.co/docs/lerobot/il_robots#train-a-policy). +Below is the short version on how to train and run inference/eval: + +### Train from scratch + +```bash +python lerobot/scripts/train.py \ + --dataset.repo_id=${HF_USER}/ \ + --policy.type=act \ + --output_dir=outputs/train/ \ + --job_name=lerobot_training \ + --policy.device=cuda \ + --policy.repo_id=${HF_USER}/ + --wandb.enable=true +``` + +*Writes checkpoints to `outputs/train//checkpoints/`.* + +### Evaluate the policy/run inference + +```bash +python -m lerobot.record \ + --robot.type=so100_follower \ + --dataset.repo_id=/eval_ \ + --policy.path=/ \ + --episodes=10 +``` + +Prefix the dataset repo with **eval\_** and supply `--policy.path` pointing to a local or hub checkpoint. + +--- + +## Model Details + +* **License:** {{ license | default("\[More Information Needed]", true) }} diff --git a/media/aloha/follower_rest.webp b/media/aloha/follower_rest.webp deleted file mode 100644 index 03698acd6..000000000 Binary files a/media/aloha/follower_rest.webp and /dev/null differ diff --git a/media/aloha/follower_rotated.webp b/media/aloha/follower_rotated.webp deleted file mode 100644 index 914958bbc..000000000 Binary files a/media/aloha/follower_rotated.webp and /dev/null differ diff --git a/media/aloha/follower_zero.webp b/media/aloha/follower_zero.webp deleted file mode 100644 index c14c516cc..000000000 Binary files a/media/aloha/follower_zero.webp and /dev/null differ diff --git a/media/aloha/leader_rest.webp b/media/aloha/leader_rest.webp deleted file mode 100644 index 821fdf7b3..000000000 Binary files a/media/aloha/leader_rest.webp and /dev/null differ diff --git a/media/aloha/leader_rotated.webp b/media/aloha/leader_rotated.webp deleted file mode 100644 index ed4a3faa7..000000000 Binary files a/media/aloha/leader_rotated.webp and /dev/null differ diff --git a/media/aloha/leader_zero.webp b/media/aloha/leader_zero.webp deleted file mode 100644 index b67cfa773..000000000 Binary files a/media/aloha/leader_zero.webp and /dev/null differ diff --git a/media/koch/follower_rest.webp b/media/koch/follower_rest.webp deleted file mode 100644 index 0a14d074c..000000000 Binary files a/media/koch/follower_rest.webp and /dev/null differ diff --git a/media/koch/follower_rotated.webp b/media/koch/follower_rotated.webp deleted file mode 100644 index 3a91d2490..000000000 Binary files a/media/koch/follower_rotated.webp and /dev/null differ diff --git a/media/koch/follower_zero.webp b/media/koch/follower_zero.webp deleted file mode 100644 index aa107ed3d..000000000 Binary files a/media/koch/follower_zero.webp and /dev/null differ diff --git a/media/koch/leader_rest.webp b/media/koch/leader_rest.webp deleted file mode 100644 index e0454cfd4..000000000 Binary files a/media/koch/leader_rest.webp and /dev/null differ diff --git a/media/koch/leader_rotated.webp b/media/koch/leader_rotated.webp deleted file mode 100644 index 183e4206e..000000000 Binary files a/media/koch/leader_rotated.webp and /dev/null differ diff --git a/media/koch/leader_zero.webp b/media/koch/leader_zero.webp deleted file mode 100644 index f3b885acf..000000000 Binary files a/media/koch/leader_zero.webp and /dev/null differ diff --git a/media/lekiwi/mobile_calib_rest.webp b/media/lekiwi/mobile_calib_rest.webp deleted file mode 100644 index a8f383da0..000000000 Binary files a/media/lekiwi/mobile_calib_rest.webp and /dev/null differ diff --git a/media/lekiwi/mobile_calib_rotated.webp b/media/lekiwi/mobile_calib_rotated.webp deleted file mode 100644 index dab8ed075..000000000 Binary files a/media/lekiwi/mobile_calib_rotated.webp and /dev/null differ diff --git a/media/lekiwi/mobile_calib_zero.webp b/media/lekiwi/mobile_calib_zero.webp deleted file mode 100644 index a067fb2ab..000000000 Binary files a/media/lekiwi/mobile_calib_zero.webp and /dev/null differ diff --git a/media/lekiwi/motor_ids.webp b/media/lekiwi/motor_ids.webp deleted file mode 100644 index 98099c89f..000000000 Binary files a/media/lekiwi/motor_ids.webp and /dev/null differ diff --git a/media/moss/follower_initial.webp b/media/moss/follower_initial.webp deleted file mode 100644 index e7ded16bd..000000000 Binary files a/media/moss/follower_initial.webp and /dev/null differ diff --git a/media/moss/follower_rest.webp b/media/moss/follower_rest.webp deleted file mode 100644 index f0dba18bd..000000000 Binary files a/media/moss/follower_rest.webp and /dev/null differ diff --git a/media/moss/follower_rotated.webp b/media/moss/follower_rotated.webp deleted file mode 100644 index 23d5aa9c1..000000000 Binary files a/media/moss/follower_rotated.webp and /dev/null differ diff --git a/media/moss/follower_zero.webp b/media/moss/follower_zero.webp deleted file mode 100644 index 10ef83704..000000000 Binary files a/media/moss/follower_zero.webp and /dev/null differ diff --git a/media/moss/leader_rest.webp b/media/moss/leader_rest.webp deleted file mode 100644 index cd77d294d..000000000 Binary files a/media/moss/leader_rest.webp and /dev/null differ diff --git a/media/moss/leader_rotated.webp b/media/moss/leader_rotated.webp deleted file mode 100644 index c3426650a..000000000 Binary files a/media/moss/leader_rotated.webp and /dev/null differ diff --git a/media/moss/leader_zero.webp b/media/moss/leader_zero.webp deleted file mode 100644 index d79ed3736..000000000 Binary files a/media/moss/leader_zero.webp and /dev/null differ diff --git a/media/so100/follower_initial.webp b/media/so100/follower_initial.webp deleted file mode 100644 index 7f93a773a..000000000 Binary files a/media/so100/follower_initial.webp and /dev/null differ diff --git a/media/so100/follower_rest.webp b/media/so100/follower_rest.webp deleted file mode 100644 index 971fbc684..000000000 Binary files a/media/so100/follower_rest.webp and /dev/null differ diff --git a/media/so100/follower_rotated.webp b/media/so100/follower_rotated.webp deleted file mode 100644 index b13d7d7d5..000000000 Binary files a/media/so100/follower_rotated.webp and /dev/null differ diff --git a/media/so100/follower_zero.webp b/media/so100/follower_zero.webp deleted file mode 100644 index 411a55545..000000000 Binary files a/media/so100/follower_zero.webp and /dev/null differ diff --git a/media/so100/leader_rest.webp b/media/so100/leader_rest.webp deleted file mode 100644 index 351667778..000000000 Binary files a/media/so100/leader_rest.webp and /dev/null differ diff --git a/media/so100/leader_rotated.webp b/media/so100/leader_rotated.webp deleted file mode 100644 index 1f770f6ce..000000000 Binary files a/media/so100/leader_rotated.webp and /dev/null differ diff --git a/media/so100/leader_zero.webp b/media/so100/leader_zero.webp deleted file mode 100644 index 5f8c235f9..000000000 Binary files a/media/so100/leader_zero.webp and /dev/null differ diff --git a/media/so101/so101-leader.webp b/media/so101/so101-leader.webp new file mode 100644 index 000000000..22ff3a4bc Binary files /dev/null and b/media/so101/so101-leader.webp differ diff --git a/media/so101/so101.webp b/media/so101/so101.webp new file mode 100644 index 000000000..ce65e94bc Binary files /dev/null and b/media/so101/so101.webp differ diff --git a/media/tutorial/koch_v1_1_leader_follower.webp b/media/tutorial/koch_v1_1_leader_follower.webp deleted file mode 100644 index f576a531a..000000000 Binary files a/media/tutorial/koch_v1_1_leader_follower.webp and /dev/null differ diff --git a/media/tutorial/visualize_dataset_html.webp b/media/tutorial/visualize_dataset_html.webp deleted file mode 100644 index e71bc5629..000000000 Binary files a/media/tutorial/visualize_dataset_html.webp and /dev/null differ diff --git a/tests/artifacts/cameras/image_128x128.png b/tests/artifacts/cameras/image_128x128.png new file mode 100644 index 000000000..b117f49f2 --- /dev/null +++ b/tests/artifacts/cameras/image_128x128.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dc9df05797dc0e7b92edc845caab2e4c37c3cfcabb4ee6339c67212b5baba3b +size 38023 diff --git a/tests/artifacts/cameras/image_160x120.png b/tests/artifacts/cameras/image_160x120.png new file mode 100644 index 000000000..cdc681d18 --- /dev/null +++ b/tests/artifacts/cameras/image_160x120.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e11af87616b83c1cdb30330e951b91e86b51c64a1326e1ba5b4a3fbcdec1a11 +size 55698 diff --git a/tests/artifacts/cameras/image_320x180.png b/tests/artifacts/cameras/image_320x180.png new file mode 100644 index 000000000..4cfd511a7 --- /dev/null +++ b/tests/artifacts/cameras/image_320x180.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8840fb643afe903191248703b1f95a57faf5812ecd9978ac502ee939646fdb2 +size 121115 diff --git a/tests/artifacts/cameras/image_480x270.png b/tests/artifacts/cameras/image_480x270.png new file mode 100644 index 000000000..b564d5424 --- /dev/null +++ b/tests/artifacts/cameras/image_480x270.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f79d14daafb1c0cf2fec5d46ee8029a73fe357402fdd31a7cd4a4794d7319a7c +size 260367 diff --git a/tests/artifacts/cameras/test_rs.bag b/tests/artifacts/cameras/test_rs.bag new file mode 100644 index 000000000..1b9662c35 --- /dev/null +++ b/tests/artifacts/cameras/test_rs.bag @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8d6e64d6cb0e02c94ae125630ee758055bd2e695772c0463a30d63ddc6c5e17 +size 3520862 diff --git a/tests/cameras/mock_cv2.py b/tests/cameras/mock_cv2.py deleted file mode 100644 index eeaf859cc..000000000 --- a/tests/cameras/mock_cv2.py +++ /dev/null @@ -1,101 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from functools import cache - -import numpy as np - -CAP_V4L2 = 200 -CAP_DSHOW = 700 -CAP_AVFOUNDATION = 1200 -CAP_ANY = -1 - -CAP_PROP_FPS = 5 -CAP_PROP_FRAME_WIDTH = 3 -CAP_PROP_FRAME_HEIGHT = 4 -COLOR_RGB2BGR = 4 -COLOR_BGR2RGB = 4 - -ROTATE_90_COUNTERCLOCKWISE = 2 -ROTATE_90_CLOCKWISE = 0 -ROTATE_180 = 1 - - -@cache -def _generate_image(width: int, height: int): - return np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8) - - -def cvtColor(color_image, color_conversion): # noqa: N802 - if color_conversion in [COLOR_RGB2BGR, COLOR_BGR2RGB]: - return color_image[:, :, [2, 1, 0]] - else: - raise NotImplementedError(color_conversion) - - -def rotate(color_image, rotation): - if rotation is None: - return color_image - elif rotation == ROTATE_90_CLOCKWISE: - return np.rot90(color_image, k=1) - elif rotation == ROTATE_180: - return np.rot90(color_image, k=2) - elif rotation == ROTATE_90_COUNTERCLOCKWISE: - return np.rot90(color_image, k=3) - else: - raise NotImplementedError(rotation) - - -class VideoCapture: - def __init__(self, *args, **kwargs): - self._mock_dict = { - CAP_PROP_FPS: 30, - CAP_PROP_FRAME_WIDTH: 640, - CAP_PROP_FRAME_HEIGHT: 480, - } - self._is_opened = True - - def isOpened(self): # noqa: N802 - return self._is_opened - - def set(self, propId: int, value: float) -> bool: # noqa: N803 - if not self._is_opened: - raise RuntimeError("Camera is not opened") - self._mock_dict[propId] = value - return True - - def get(self, propId: int) -> float: # noqa: N803 - if not self._is_opened: - raise RuntimeError("Camera is not opened") - value = self._mock_dict[propId] - if value == 0: - if propId == CAP_PROP_FRAME_HEIGHT: - value = 480 - elif propId == CAP_PROP_FRAME_WIDTH: - value = 640 - return value - - def read(self): - if not self._is_opened: - raise RuntimeError("Camera is not opened") - h = self.get(CAP_PROP_FRAME_HEIGHT) - w = self.get(CAP_PROP_FRAME_WIDTH) - ret = True - return ret, _generate_image(width=w, height=h) - - def release(self): - self._is_opened = False - - def __del__(self): - if self._is_opened: - self.release() diff --git a/tests/cameras/mock_pyrealsense2.py b/tests/cameras/mock_pyrealsense2.py deleted file mode 100644 index c477eb062..000000000 --- a/tests/cameras/mock_pyrealsense2.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import enum - -import numpy as np - - -class stream(enum.Enum): # noqa: N801 - color = 0 - depth = 1 - - -class format(enum.Enum): # noqa: N801 - rgb8 = 0 - z16 = 1 - - -class config: # noqa: N801 - def enable_device(self, device_id: str): - self.device_enabled = device_id - - def enable_stream(self, stream_type: stream, width=None, height=None, color_format=None, fps=None): - self.stream_type = stream_type - # Overwrite default values when possible - self.width = 848 if width is None else width - self.height = 480 if height is None else height - self.color_format = format.rgb8 if color_format is None else color_format - self.fps = 30 if fps is None else fps - - -class RSColorProfile: - def __init__(self, config): - self.config = config - - def fps(self): - return self.config.fps - - def width(self): - return self.config.width - - def height(self): - return self.config.height - - -class RSColorStream: - def __init__(self, config): - self.config = config - - def as_video_stream_profile(self): - return RSColorProfile(self.config) - - -class RSProfile: - def __init__(self, config): - self.config = config - - def get_stream(self, color_format): - del color_format # unused - return RSColorStream(self.config) - - -class pipeline: # noqa: N801 - def __init__(self): - self.started = False - self.config = None - - def start(self, config): - self.started = True - self.config = config - return RSProfile(self.config) - - def stop(self): - if not self.started: - raise RuntimeError("You need to start the camera before stop.") - self.started = False - self.config = None - - def wait_for_frames(self, timeout_ms=50000): - del timeout_ms # unused - return RSFrames(self.config) - - -class RSFrames: - def __init__(self, config): - self.config = config - - def get_color_frame(self): - return RSColorFrame(self.config) - - def get_depth_frame(self): - return RSDepthFrame(self.config) - - -class RSColorFrame: - def __init__(self, config): - self.config = config - - def get_data(self): - data = np.ones((self.config.height, self.config.width, 3), dtype=np.uint8) - # Create a difference between rgb and bgr - data[:, :, 0] = 2 - return data - - -class RSDepthFrame: - def __init__(self, config): - self.config = config - - def get_data(self): - return np.ones((self.config.height, self.config.width), dtype=np.uint16) - - -class RSDevice: - def __init__(self): - pass - - def get_info(self, camera_info) -> str: - del camera_info # unused - # return fake serial number - return "123456789" - - -class context: # noqa: N801 - def __init__(self): - pass - - def query_devices(self): - return [RSDevice()] - - -class camera_info: # noqa: N801 - # fake name - name = "Intel RealSense D435I" - - def __init__(self, serial_number): - del serial_number - pass diff --git a/tests/cameras/test_cameras.py b/tests/cameras/test_cameras.py deleted file mode 100644 index 868358ece..000000000 --- a/tests/cameras/test_cameras.py +++ /dev/null @@ -1,252 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for physical cameras and their mocked versions. -If the physical camera is not connected to the computer, or not working, -the test will be skipped. - -Example of running a specific test: -```bash -pytest -sx tests/test_cameras.py::test_camera -``` - -Example of running test on a real camera connected to the computer: -```bash -pytest -sx 'tests/test_cameras.py::test_camera[opencv-False]' -pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-False]' -``` - -Example of running test on a mocked version of the camera: -```bash -pytest -sx 'tests/test_cameras.py::test_camera[opencv-True]' -pytest -sx 'tests/test_cameras.py::test_camera[intelrealsense-True]' -``` -""" - -import numpy as np -import pytest - -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError -from tests.utils import TEST_CAMERA_TYPES, make_camera, require_camera - -# Maximum absolute difference between two consecutive images recorded by a camera. -# This value differs with respect to the camera. -MAX_PIXEL_DIFFERENCE = 25 - - -def compute_max_pixel_difference(first_image, second_image): - return np.abs(first_image.astype(float) - second_image.astype(float)).max() - - -@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) -@require_camera -def test_camera(request, camera_type, mock): - """Test assumes that `camera.read()` returns the same image when called multiple times in a row. - So the environment should not change (you shouldnt be in front of the camera) and the camera should not be moving. - - Warning: The tests worked for a macbookpro camera, but I am getting assertion error (`np.allclose(color_image, async_color_image)`) - for my iphone camera and my LG monitor camera. - """ - # TODO(rcadene): measure fps in nightly? - # TODO(rcadene): test logs - - if camera_type == "opencv" and not mock: - pytest.skip("TODO(rcadene): fix test for opencv physical camera") - - camera_kwargs = {"camera_type": camera_type, "mock": mock} - - # Test instantiating - camera = make_camera(**camera_kwargs) - - # Test reading, async reading, disconnecting before connecting raises an error - with pytest.raises(RobotDeviceNotConnectedError): - camera.read() - with pytest.raises(RobotDeviceNotConnectedError): - camera.async_read() - with pytest.raises(RobotDeviceNotConnectedError): - camera.disconnect() - - # Test deleting the object without connecting first - del camera - - # Test connecting - camera = make_camera(**camera_kwargs) - camera.connect() - assert camera.is_connected - assert camera.fps is not None - assert camera.capture_width is not None - assert camera.capture_height is not None - - # Test connecting twice raises an error - with pytest.raises(RobotDeviceAlreadyConnectedError): - camera.connect() - - # Test reading from the camera - color_image = camera.read() - assert isinstance(color_image, np.ndarray) - assert color_image.ndim == 3 - h, w, c = color_image.shape - assert c == 3 - assert w > h - - # Test read and async_read outputs similar images - # ...warming up as the first frames can be black - for _ in range(30): - camera.read() - color_image = camera.read() - async_color_image = camera.async_read() - error_msg = ( - "max_pixel_difference between read() and async_read()", - compute_max_pixel_difference(color_image, async_color_image), - ) - # TODO(rcadene): properly set `rtol` - np.testing.assert_allclose( - color_image, async_color_image, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg - ) - - # Test disconnecting - camera.disconnect() - assert camera.camera is None - assert camera.thread is None - - # Test disconnecting with `__del__` - camera = make_camera(**camera_kwargs) - camera.connect() - del camera - - # Test acquiring a bgr image - camera = make_camera(**camera_kwargs, color_mode="bgr") - camera.connect() - assert camera.color_mode == "bgr" - bgr_color_image = camera.read() - np.testing.assert_allclose( - color_image, bgr_color_image[:, :, [2, 1, 0]], rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg - ) - del camera - - # Test acquiring a rotated image - camera = make_camera(**camera_kwargs) - camera.connect() - ori_color_image = camera.read() - del camera - - for rotation in [None, 90, 180, -90]: - camera = make_camera(**camera_kwargs, rotation=rotation) - camera.connect() - - if mock: - import tests.cameras.mock_cv2 as cv2 - else: - import cv2 - - if rotation is None: - manual_rot_img = ori_color_image - assert camera.rotation is None - elif rotation == 90: - manual_rot_img = np.rot90(color_image, k=1) - assert camera.rotation == cv2.ROTATE_90_CLOCKWISE - elif rotation == 180: - manual_rot_img = np.rot90(color_image, k=2) - assert camera.rotation == cv2.ROTATE_180 - elif rotation == -90: - manual_rot_img = np.rot90(color_image, k=3) - assert camera.rotation == cv2.ROTATE_90_COUNTERCLOCKWISE - - rot_color_image = camera.read() - - np.testing.assert_allclose( - rot_color_image, manual_rot_img, rtol=1e-5, atol=MAX_PIXEL_DIFFERENCE, err_msg=error_msg - ) - del camera - - # TODO(rcadene): Add a test for a camera that doesnt support fps=60 and raises an OSError - # TODO(rcadene): Add a test for a camera that supports fps=60 - - # Test width and height can be set - camera = make_camera(**camera_kwargs, fps=30, width=1280, height=720) - camera.connect() - assert camera.fps == 30 - assert camera.width == 1280 - assert camera.height == 720 - color_image = camera.read() - h, w, c = color_image.shape - assert h == 720 - assert w == 1280 - assert c == 3 - del camera - - # Test not supported width and height raise an error - camera = make_camera(**camera_kwargs, fps=30, width=0, height=0) - with pytest.raises(OSError): - camera.connect() - del camera - - -@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) -@require_camera -def test_save_images_from_cameras(tmp_path, request, camera_type, mock): - # TODO(rcadene): refactor - if camera_type == "opencv": - from lerobot.common.robot_devices.cameras.opencv import save_images_from_cameras - elif camera_type == "intelrealsense": - from lerobot.common.robot_devices.cameras.intelrealsense import save_images_from_cameras - - # Small `record_time_s` to speedup unit tests - save_images_from_cameras(tmp_path, record_time_s=0.02, mock=mock) - - -@pytest.mark.parametrize("camera_type, mock", TEST_CAMERA_TYPES) -@require_camera -def test_camera_rotation(request, camera_type, mock): - config_kwargs = {"camera_type": camera_type, "mock": mock, "width": 640, "height": 480, "fps": 30} - - # No rotation. - camera = make_camera(**config_kwargs, rotation=None) - camera.connect() - assert camera.capture_width == 640 - assert camera.capture_height == 480 - assert camera.width == 640 - assert camera.height == 480 - no_rot_img = camera.read() - h, w, c = no_rot_img.shape - assert h == 480 and w == 640 and c == 3 - camera.disconnect() - - # Rotation = 90 (clockwise). - camera = make_camera(**config_kwargs, rotation=90) - camera.connect() - # With a 90° rotation, we expect the metadata dimensions to be swapped. - assert camera.capture_width == 640 - assert camera.capture_height == 480 - assert camera.width == 480 - assert camera.height == 640 - import cv2 - - assert camera.rotation == cv2.ROTATE_90_CLOCKWISE - rot_img = camera.read() - h, w, c = rot_img.shape - assert h == 640 and w == 480 and c == 3 - camera.disconnect() - - # Rotation = 180. - camera = make_camera(**config_kwargs, rotation=None) - camera.connect() - assert camera.capture_width == 640 - assert camera.capture_height == 480 - assert camera.width == 640 - assert camera.height == 480 - no_rot_img = camera.read() - h, w, c = no_rot_img.shape - assert h == 480 and w == 640 and c == 3 - camera.disconnect() diff --git a/tests/cameras/test_opencv.py b/tests/cameras/test_opencv.py new file mode 100644 index 000000000..7ba04b261 --- /dev/null +++ b/tests/cameras/test_opencv.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example of running a specific test: +# ```bash +# pytest tests/cameras/test_opencv.py::test_connect +# ``` + +from pathlib import Path + +import numpy as np +import pytest + +from lerobot.common.cameras.configs import Cv2Rotation +from lerobot.common.cameras.opencv import OpenCVCamera, OpenCVCameraConfig +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +# NOTE(Steven): more tests + assertions? +TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" +DEFAULT_PNG_FILE_PATH = TEST_ARTIFACTS_DIR / "image_160x120.png" +TEST_IMAGE_SIZES = ["128x128", "160x120", "320x180", "480x270"] +TEST_IMAGE_PATHS = [TEST_ARTIFACTS_DIR / f"image_{size}.png" for size in TEST_IMAGE_SIZES] + + +def test_abc_implementation(): + """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" + config = OpenCVCameraConfig(index_or_path=0) + + _ = OpenCVCamera(config) + + +def test_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + + camera.connect(warmup=False) + + assert camera.is_connected + + +def test_connect_already_connected(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + camera.connect(warmup=False) + + with pytest.raises(DeviceAlreadyConnectedError): + camera.connect(warmup=False) + + +def test_connect_invalid_camera_path(): + config = OpenCVCameraConfig(index_or_path="nonexistent/camera.png") + camera = OpenCVCamera(config) + + with pytest.raises(ConnectionError): + camera.connect(warmup=False) + + +def test_invalid_width_connect(): + config = OpenCVCameraConfig( + index_or_path=DEFAULT_PNG_FILE_PATH, + width=99999, # Invalid width to trigger error + height=480, + ) + camera = OpenCVCamera(config) + + with pytest.raises(RuntimeError): + camera.connect(warmup=False) + + +@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) +def test_read(index_or_path): + config = OpenCVCameraConfig(index_or_path=index_or_path) + camera = OpenCVCamera(config) + camera.connect(warmup=False) + + img = camera.read() + + assert isinstance(img, np.ndarray) + + +def test_read_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.read() + + +def test_disconnect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + camera.connect(warmup=False) + + camera.disconnect() + + assert not camera.is_connected + + +def test_disconnect_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.disconnect() + + +@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) +def test_async_read(index_or_path): + config = OpenCVCameraConfig(index_or_path=index_or_path) + camera = OpenCVCamera(config) + camera.connect(warmup=False) + + try: + img = camera.async_read() + + assert camera.thread is not None + assert camera.thread.is_alive() + assert isinstance(img, np.ndarray) + finally: + if camera.is_connected: + camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends + + +def test_async_read_timeout(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + camera.connect(warmup=False) + + try: + with pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) + finally: + if camera.is_connected: + camera.disconnect() + + +def test_async_read_before_connect(): + config = OpenCVCameraConfig(index_or_path=DEFAULT_PNG_FILE_PATH) + camera = OpenCVCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.async_read() + + +@pytest.mark.parametrize("index_or_path", TEST_IMAGE_PATHS, ids=TEST_IMAGE_SIZES) +@pytest.mark.parametrize( + "rotation", + [ + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ], + ids=["no_rot", "rot90", "rot180", "rot270"], +) +def test_rotation(rotation, index_or_path): + filename = Path(index_or_path).name + dimensions = filename.split("_")[-1].split(".")[0] # Assumes filenames format (_wxh.png) + original_width, original_height = map(int, dimensions.split("x")) + + config = OpenCVCameraConfig(index_or_path=index_or_path, rotation=rotation) + camera = OpenCVCamera(config) + camera.connect(warmup=False) + + img = camera.read() + assert isinstance(img, np.ndarray) + + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == original_height + assert camera.height == original_width + assert img.shape[:2] == (original_width, original_height) + else: + assert camera.width == original_width + assert camera.height == original_height + assert img.shape[:2] == (original_height, original_width) diff --git a/tests/cameras/test_realsense.py b/tests/cameras/test_realsense.py new file mode 100644 index 000000000..5fb1767fe --- /dev/null +++ b/tests/cameras/test_realsense.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Example of running a specific test: +# ```bash +# pytest tests/cameras/test_opencv.py::test_connect +# ``` + +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import pytest + +from lerobot.common.cameras.configs import Cv2Rotation +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError + +pytest.importorskip("pyrealsense2") + +from lerobot.common.cameras.realsense import RealSenseCamera, RealSenseCameraConfig + +TEST_ARTIFACTS_DIR = Path(__file__).parent.parent / "artifacts" / "cameras" +BAG_FILE_PATH = TEST_ARTIFACTS_DIR / "test_rs.bag" + +# NOTE(Steven): For some reason these tests take ~20sec in macOS but only ~2sec in Linux. + + +def mock_rs_config_enable_device_from_file(rs_config_instance, _sn): + return rs_config_instance.enable_device_from_file(str(BAG_FILE_PATH), repeat_playback=True) + + +def mock_rs_config_enable_device_bad_file(rs_config_instance, _sn): + return rs_config_instance.enable_device_from_file("non_existent_file.bag", repeat_playback=True) + + +@pytest.fixture(name="patch_realsense", autouse=True) +def fixture_patch_realsense(): + """Automatically mock pyrealsense2.config.enable_device for all tests.""" + with patch( + "pyrealsense2.config.enable_device", side_effect=mock_rs_config_enable_device_from_file + ) as mock: + yield mock + + +def test_abc_implementation(): + """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" + config = RealSenseCameraConfig(serial_number_or_name="042") + _ = RealSenseCamera(config) + + +def test_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + camera.connect(warmup=False) + assert camera.is_connected + + +def test_connect_already_connected(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + camera.connect(warmup=False) + + with pytest.raises(DeviceAlreadyConnectedError): + camera.connect(warmup=False) + + +def test_connect_invalid_camera_path(patch_realsense): + patch_realsense.side_effect = mock_rs_config_enable_device_bad_file + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + with pytest.raises(ConnectionError): + camera.connect(warmup=False) + + +def test_invalid_width_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=99999, height=480, fps=30) + camera = RealSenseCamera(config) + + with pytest.raises(ConnectionError): + camera.connect(warmup=False) + + +def test_read(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) + camera = RealSenseCamera(config) + camera.connect(warmup=False) + + img = camera.read() + assert isinstance(img, np.ndarray) + + +def test_read_depth(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30, use_depth=True) + camera = RealSenseCamera(config) + camera.connect(warmup=False) + + img = camera.read_depth(timeout_ms=1000) # NOTE(Steven): Reading depth takes longer + assert isinstance(img, np.ndarray) + + +def test_read_before_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.read() + + +def test_disconnect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + camera.connect(warmup=False) + + camera.disconnect() + + assert not camera.is_connected + + +def test_disconnect_before_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + camera.disconnect() + + +def test_async_read(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) + camera = RealSenseCamera(config) + camera.connect(warmup=False) + + try: + img = camera.async_read() + + assert camera.thread is not None + assert camera.thread.is_alive() + assert isinstance(img, np.ndarray) + finally: + if camera.is_connected: + camera.disconnect() # To stop/join the thread. Otherwise get warnings when the test ends + + +def test_async_read_timeout(): + config = RealSenseCameraConfig(serial_number_or_name="042", width=640, height=480, fps=30) + camera = RealSenseCamera(config) + camera.connect(warmup=False) + + try: + with pytest.raises(TimeoutError): + camera.async_read(timeout_ms=0) + finally: + if camera.is_connected: + camera.disconnect() + + +def test_async_read_before_connect(): + config = RealSenseCameraConfig(serial_number_or_name="042") + camera = RealSenseCamera(config) + + with pytest.raises(DeviceNotConnectedError): + _ = camera.async_read() + + +@pytest.mark.parametrize( + "rotation", + [ + Cv2Rotation.NO_ROTATION, + Cv2Rotation.ROTATE_90, + Cv2Rotation.ROTATE_180, + Cv2Rotation.ROTATE_270, + ], + ids=["no_rot", "rot90", "rot180", "rot270"], +) +def test_rotation(rotation): + config = RealSenseCameraConfig(serial_number_or_name="042", rotation=rotation) + camera = RealSenseCamera(config) + camera.connect(warmup=False) + + img = camera.read() + assert isinstance(img, np.ndarray) + + if rotation in (Cv2Rotation.ROTATE_90, Cv2Rotation.ROTATE_270): + assert camera.width == 480 + assert camera.height == 640 + assert img.shape[:2] == (640, 480) + else: + assert camera.width == 640 + assert camera.height == 480 + assert img.shape[:2] == (480, 640) diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 352aba999..146a4dcd4 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -16,6 +16,7 @@ import pytest import torch +from packaging import version from safetensors.torch import load_file from torchvision.transforms import v2 from torchvision.transforms.v2 import functional as F # noqa: N812 @@ -253,7 +254,14 @@ def test_backward_compatibility_single_transforms( @require_x86_64_kernel +@pytest.mark.skipif( + version.parse(torch.__version__) < version.parse("2.7.0"), + reason="Test artifacts were generated with PyTorch >= 2.7.0 which has different multinomial behavior", +) def test_backward_compatibility_default_config(img_tensor, default_transforms): + # NOTE: PyTorch versions have different randomness, it might break this test. + # See this PR: https://github.com/huggingface/lerobot/pull/1127. + cfg = ImageTransformsConfig(enable=True) default_tf = ImageTransforms(cfg) diff --git a/tests/mocks/mock_dynamixel.py b/tests/mocks/mock_dynamixel.py new file mode 100644 index 000000000..d446bf272 --- /dev/null +++ b/tests/mocks/mock_dynamixel.py @@ -0,0 +1,580 @@ +import abc +from typing import Callable + +import dynamixel_sdk as dxl +import serial +from mock_serial.mock_serial import MockSerial + +from lerobot.common.motors.dynamixel.dynamixel import _split_into_byte_chunks + +from .mock_serial_patch import WaitableStub + +# https://emanual.robotis.com/docs/en/dxl/crc/ +DXL_CRC_TABLE = [ + 0x0000, 0x8005, 0x800F, 0x000A, 0x801B, 0x001E, 0x0014, 0x8011, + 0x8033, 0x0036, 0x003C, 0x8039, 0x0028, 0x802D, 0x8027, 0x0022, + 0x8063, 0x0066, 0x006C, 0x8069, 0x0078, 0x807D, 0x8077, 0x0072, + 0x0050, 0x8055, 0x805F, 0x005A, 0x804B, 0x004E, 0x0044, 0x8041, + 0x80C3, 0x00C6, 0x00CC, 0x80C9, 0x00D8, 0x80DD, 0x80D7, 0x00D2, + 0x00F0, 0x80F5, 0x80FF, 0x00FA, 0x80EB, 0x00EE, 0x00E4, 0x80E1, + 0x00A0, 0x80A5, 0x80AF, 0x00AA, 0x80BB, 0x00BE, 0x00B4, 0x80B1, + 0x8093, 0x0096, 0x009C, 0x8099, 0x0088, 0x808D, 0x8087, 0x0082, + 0x8183, 0x0186, 0x018C, 0x8189, 0x0198, 0x819D, 0x8197, 0x0192, + 0x01B0, 0x81B5, 0x81BF, 0x01BA, 0x81AB, 0x01AE, 0x01A4, 0x81A1, + 0x01E0, 0x81E5, 0x81EF, 0x01EA, 0x81FB, 0x01FE, 0x01F4, 0x81F1, + 0x81D3, 0x01D6, 0x01DC, 0x81D9, 0x01C8, 0x81CD, 0x81C7, 0x01C2, + 0x0140, 0x8145, 0x814F, 0x014A, 0x815B, 0x015E, 0x0154, 0x8151, + 0x8173, 0x0176, 0x017C, 0x8179, 0x0168, 0x816D, 0x8167, 0x0162, + 0x8123, 0x0126, 0x012C, 0x8129, 0x0138, 0x813D, 0x8137, 0x0132, + 0x0110, 0x8115, 0x811F, 0x011A, 0x810B, 0x010E, 0x0104, 0x8101, + 0x8303, 0x0306, 0x030C, 0x8309, 0x0318, 0x831D, 0x8317, 0x0312, + 0x0330, 0x8335, 0x833F, 0x033A, 0x832B, 0x032E, 0x0324, 0x8321, + 0x0360, 0x8365, 0x836F, 0x036A, 0x837B, 0x037E, 0x0374, 0x8371, + 0x8353, 0x0356, 0x035C, 0x8359, 0x0348, 0x834D, 0x8347, 0x0342, + 0x03C0, 0x83C5, 0x83CF, 0x03CA, 0x83DB, 0x03DE, 0x03D4, 0x83D1, + 0x83F3, 0x03F6, 0x03FC, 0x83F9, 0x03E8, 0x83ED, 0x83E7, 0x03E2, + 0x83A3, 0x03A6, 0x03AC, 0x83A9, 0x03B8, 0x83BD, 0x83B7, 0x03B2, + 0x0390, 0x8395, 0x839F, 0x039A, 0x838B, 0x038E, 0x0384, 0x8381, + 0x0280, 0x8285, 0x828F, 0x028A, 0x829B, 0x029E, 0x0294, 0x8291, + 0x82B3, 0x02B6, 0x02BC, 0x82B9, 0x02A8, 0x82AD, 0x82A7, 0x02A2, + 0x82E3, 0x02E6, 0x02EC, 0x82E9, 0x02F8, 0x82FD, 0x82F7, 0x02F2, + 0x02D0, 0x82D5, 0x82DF, 0x02DA, 0x82CB, 0x02CE, 0x02C4, 0x82C1, + 0x8243, 0x0246, 0x024C, 0x8249, 0x0258, 0x825D, 0x8257, 0x0252, + 0x0270, 0x8275, 0x827F, 0x027A, 0x826B, 0x026E, 0x0264, 0x8261, + 0x0220, 0x8225, 0x822F, 0x022A, 0x823B, 0x023E, 0x0234, 0x8231, + 0x8213, 0x0216, 0x021C, 0x8219, 0x0208, 0x820D, 0x8207, 0x0202 +] # fmt: skip + + +class MockDynamixelPacketv2(abc.ABC): + @classmethod + def build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> bytes: + packet = cls._build(dxl_id, params, length, *args, **kwargs) + packet = cls._add_stuffing(packet) + packet = cls._add_crc(packet) + return bytes(packet) + + @abc.abstractclassmethod + def _build(cls, dxl_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]: + pass + + @staticmethod + def _add_stuffing(packet: list[int]) -> list[int]: + """ + Byte stuffing is a method of adding additional data to generated instruction packets to ensure that + the packets are processed successfully. When the byte pattern "0xFF 0xFF 0xFD" appears in a packet, + byte stuffing adds 0xFD to the end of the pattern to convert it to “0xFF 0xFF 0xFD 0xFD” to ensure + that it is not interpreted as the header at the start of another packet. + + Source: https://emanual.robotis.com/docs/en/dxl/protocol2/#transmission-process + + Args: + packet (list[int]): The raw packet without stuffing. + + Returns: + list[int]: The packet stuffed if it contained a "0xFF 0xFF 0xFD" byte sequence in its data bytes. + """ + packet_length_in = dxl.DXL_MAKEWORD(packet[dxl.PKT_LENGTH_L], packet[dxl.PKT_LENGTH_H]) + packet_length_out = packet_length_in + + temp = [0] * dxl.TXPACKET_MAX_LEN + + # FF FF FD XX ID LEN_L LEN_H + temp[dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1] = packet[ + dxl.PKT_HEADER0 : dxl.PKT_HEADER0 + dxl.PKT_LENGTH_H + 1 + ] + + index = dxl.PKT_INSTRUCTION + + for i in range(0, packet_length_in - 2): # except CRC + temp[index] = packet[i + dxl.PKT_INSTRUCTION] + index = index + 1 + if ( + packet[i + dxl.PKT_INSTRUCTION] == 0xFD + and packet[i + dxl.PKT_INSTRUCTION - 1] == 0xFF + and packet[i + dxl.PKT_INSTRUCTION - 2] == 0xFF + ): + # FF FF FD + temp[index] = 0xFD + index = index + 1 + packet_length_out = packet_length_out + 1 + + temp[index] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 2] + temp[index + 1] = packet[dxl.PKT_INSTRUCTION + packet_length_in - 1] + index = index + 2 + + if packet_length_in != packet_length_out: + packet = [0] * index + + packet[0:index] = temp[0:index] + + packet[dxl.PKT_LENGTH_L] = dxl.DXL_LOBYTE(packet_length_out) + packet[dxl.PKT_LENGTH_H] = dxl.DXL_HIBYTE(packet_length_out) + + return packet + + @staticmethod + def _add_crc(packet: list[int]) -> list[int]: + """Computes and add CRC to the packet. + + https://emanual.robotis.com/docs/en/dxl/crc/ + https://en.wikipedia.org/wiki/Cyclic_redundancy_check + + Args: + packet (list[int]): The raw packet without CRC (but with placeholders for it). + + Returns: + list[int]: The raw packet with a valid CRC. + """ + crc = 0 + for j in range(len(packet) - 2): + i = ((crc >> 8) ^ packet[j]) & 0xFF + crc = ((crc << 8) ^ DXL_CRC_TABLE[i]) & 0xFFFF + + packet[-2] = dxl.DXL_LOBYTE(crc) + packet[-1] = dxl.DXL_HIBYTE(crc) + + return packet + + +class MockInstructionPacket(MockDynamixelPacketv2): + """ + Helper class to build valid Dynamixel Protocol 2.0 Instruction Packets. + + Protocol 2.0 Instruction Packet structure + https://emanual.robotis.com/docs/en/dxl/protocol2/#instruction-packet + + | Header | Packet ID | Length | Instruction | Params | CRC | + | ------------------- | --------- | ----------- | ----------- | ----------------- | ----------- | + | 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | Instr | Param 1 … Param N | CRC_L CRC_H | + + """ + + @classmethod + def _build(cls, dxl_id: int, params: list[int], length: int, instruction: int) -> list[int]: + length = len(params) + 3 + return [ + 0xFF, 0xFF, 0xFD, 0x00, # header + dxl_id, # servo id + dxl.DXL_LOBYTE(length), # length_l + dxl.DXL_HIBYTE(length), # length_h + instruction, # instruction type + *params, # data bytes + 0x00, 0x00 # placeholder for CRC + ] # fmt: skip + + @classmethod + def ping( + cls, + dxl_id: int, + ) -> bytes: + """ + Builds a "Ping" broadcast instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01 + + No parameters required. + """ + return cls.build(dxl_id=dxl_id, params=[], length=3, instruction=dxl.INST_PING) + + @classmethod + def read( + cls, + dxl_id: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Read" instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02 + + The parameters for Read (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = data_length L + param[3] = data_length H + + And 'length' = data_length + 5, where: + +1 is for instruction byte, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + dxl.DXL_LOBYTE(data_length), + dxl.DXL_HIBYTE(data_length), + ] + length = len(params) + 3 + # length = data_length + 5 + return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_READ) + + @classmethod + def write( + cls, + dxl_id: int, + value: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Write" instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#write-0x03 + + The parameters for Write (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = 1st Byte + param[3] = 2nd Byte + ... + param[1+X] = X-th Byte + + And 'length' = data_length + 5, where: + +1 is for instruction byte, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + data = _split_into_byte_chunks(value, data_length) + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + *data, + ] + length = data_length + 5 + return cls.build(dxl_id=dxl_id, params=params, length=length, instruction=dxl.INST_WRITE) + + @classmethod + def sync_read( + cls, + dxl_ids: list[int], + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Sync_Read" broadcast instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82 + + The parameters for Sync_Read (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = data_length L + param[3] = data_length H + param[4+] = motor IDs to read from + + And 'length' = (number_of_params + 7), where: + +1 is for instruction byte, + +2 is for the address bytes, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + dxl.DXL_LOBYTE(data_length), + dxl.DXL_HIBYTE(data_length), + *dxl_ids, + ] + length = len(dxl_ids) + 7 + return cls.build( + dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_READ + ) + + @classmethod + def sync_write( + cls, + ids_values: dict[int, int], + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Sync_Write" broadcast instruction. + https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-write-0x83 + + The parameters for Sync_Write (Protocol 2.0) are: + param[0] = start_address L + param[1] = start_address H + param[2] = data_length L + param[3] = data_length H + param[5] = [1st motor] ID + param[5+1] = [1st motor] 1st Byte + param[5+2] = [1st motor] 2nd Byte + ... + param[5+X] = [1st motor] X-th Byte + param[6] = [2nd motor] ID + param[6+1] = [2nd motor] 1st Byte + param[6+2] = [2nd motor] 2nd Byte + ... + param[6+X] = [2nd motor] X-th Byte + + And 'length' = ((number_of_params * 1 + data_length) + 7), where: + +1 is for instruction byte, + +2 is for the address bytes, + +2 is for the length bytes, + +2 is for the CRC at the end. + """ + data = [] + for id_, value in ids_values.items(): + split_value = _split_into_byte_chunks(value, data_length) + data += [id_, *split_value] + params = [ + dxl.DXL_LOBYTE(start_address), + dxl.DXL_HIBYTE(start_address), + dxl.DXL_LOBYTE(data_length), + dxl.DXL_HIBYTE(data_length), + *data, + ] + length = len(ids_values) * (1 + data_length) + 7 + return cls.build( + dxl_id=dxl.BROADCAST_ID, params=params, length=length, instruction=dxl.INST_SYNC_WRITE + ) + + +class MockStatusPacket(MockDynamixelPacketv2): + """ + Helper class to build valid Dynamixel Protocol 2.0 Status Packets. + + Protocol 2.0 Status Packet structure + https://emanual.robotis.com/docs/en/dxl/protocol2/#status-packet + + | Header | Packet ID | Length | Instruction | Error | Params | CRC | + | ------------------- | --------- | ----------- | ----------- | ----- | ----------------- | ----------- | + | 0xFF 0xFF 0xFD 0x00 | ID | Len_L Len_H | 0x55 | Err | Param 1 … Param N | CRC_L CRC_H | + """ + + @classmethod + def _build(cls, dxl_id: int, params: list[int], length: int, error: int = 0) -> list[int]: + return [ + 0xFF, 0xFF, 0xFD, 0x00, # header + dxl_id, # servo id + dxl.DXL_LOBYTE(length), # length_l + dxl.DXL_HIBYTE(length), # length_h + 0x55, # instruction = 'status' + error, # error + *params, # data bytes + 0x00, 0x00 # placeholder for CRC + ] # fmt: skip + + @classmethod + def ping(cls, dxl_id: int, model_nb: int = 1190, firm_ver: int = 50, error: int = 0) -> bytes: + """ + Builds a 'Ping' status packet. + https://emanual.robotis.com/docs/en/dxl/protocol2/#ping-0x01 + + Args: + dxl_id (int): ID of the servo responding. + model_nb (int, optional): Desired 'model number' to be returned in the packet. Defaults to 1190 + which corresponds to a XL330-M077-T. + firm_ver (int, optional): Desired 'firmware version' to be returned in the packet. + Defaults to 50. + + Returns: + bytes: The raw 'Ping' status packet ready to be sent through serial. + """ + params = [dxl.DXL_LOBYTE(model_nb), dxl.DXL_HIBYTE(model_nb), firm_ver] + length = 7 + return cls.build(dxl_id, params=params, length=length, error=error) + + @classmethod + def read(cls, dxl_id: int, value: int, param_length: int, error: int = 0) -> bytes: + """ + Builds a 'Read' status packet (also works for 'Sync Read') + https://emanual.robotis.com/docs/en/dxl/protocol2/#read-0x02 + https://emanual.robotis.com/docs/en/dxl/protocol2/#sync-read-0x82 + + Args: + dxl_id (int): ID of the servo responding. + value (int): Desired value to be returned in the packet. + param_length (int): The address length as reported in the control table. + + Returns: + bytes: The raw 'Present_Position' status packet ready to be sent through serial. + """ + params = _split_into_byte_chunks(value, param_length) + length = param_length + 4 + return cls.build(dxl_id, params=params, length=length, error=error) + + +class MockPortHandler(dxl.PortHandler): + """ + This class overwrite the 'setupPort' method of the Dynamixel PortHandler because it can specify + baudrates that are not supported with a serial port on MacOS. + """ + + def setupPort(self, cflag_baud): # noqa: N802 + if self.is_open: + self.closePort() + + self.ser = serial.Serial( + port=self.port_name, + # baudrate=self.baudrate, <- This will fail on MacOS + # parity = serial.PARITY_ODD, + # stopbits = serial.STOPBITS_TWO, + bytesize=serial.EIGHTBITS, + timeout=0, + ) + self.is_open = True + self.ser.reset_input_buffer() + self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0 + + return True + + +class MockMotors(MockSerial): + """ + This class will simulate physical motors by responding with valid status packets upon receiving some + instruction packets. It is meant to test MotorsBus classes. + """ + + def __init__(self): + super().__init__() + + @property + def stubs(self) -> dict[str, WaitableStub]: + return super().stubs + + def stub(self, *, name=None, **kwargs): + new_stub = WaitableStub(**kwargs) + self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub + return new_stub + + def build_broadcast_ping_stub( + self, ids_models: dict[int, list[int]] | None = None, num_invalid_try: int = 0 + ) -> str: + ping_request = MockInstructionPacket.ping(dxl.BROADCAST_ID) + return_packets = b"".join(MockStatusPacket.ping(id_, model) for id_, model in ids_models.items()) + ping_response = self._build_send_fn(return_packets, num_invalid_try) + + stub_name = "Ping_" + "_".join([str(id_) for id_ in ids_models]) + self.stub( + name=stub_name, + receive_bytes=ping_request, + send_fn=ping_response, + ) + return stub_name + + def build_ping_stub( + self, dxl_id: int, model_nb: int, firm_ver: int = 50, num_invalid_try: int = 0, error: int = 0 + ) -> str: + ping_request = MockInstructionPacket.ping(dxl_id) + return_packet = MockStatusPacket.ping(dxl_id, model_nb, firm_ver, error) + ping_response = self._build_send_fn(return_packet, num_invalid_try) + stub_name = f"Ping_{dxl_id}" + self.stub( + name=stub_name, + receive_bytes=ping_request, + send_fn=ping_response, + ) + return stub_name + + def build_read_stub( + self, + address: int, + length: int, + dxl_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, + ) -> str: + read_request = MockInstructionPacket.read(dxl_id, address, length) + return_packet = MockStatusPacket.read(dxl_id, value, length, error) if reply else b"" + read_response = self._build_send_fn(return_packet, num_invalid_try) + stub_name = f"Read_{address}_{length}_{dxl_id}_{value}_{error}" + self.stub( + name=stub_name, + receive_bytes=read_request, + send_fn=read_response, + ) + return stub_name + + def build_write_stub( + self, + address: int, + length: int, + dxl_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, + ) -> str: + sync_read_request = MockInstructionPacket.write(dxl_id, value, address, length) + return_packet = MockStatusPacket.build(dxl_id, params=[], length=4, error=error) if reply else b"" + stub_name = f"Write_{address}_{length}_{dxl_id}" + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(return_packet, num_invalid_try), + ) + return stub_name + + def build_sync_read_stub( + self, + address: int, + length: int, + ids_values: dict[int, int], + reply: bool = True, + num_invalid_try: int = 0, + ) -> str: + sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) + return_packets = ( + b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + if reply + else b"" + ) + sync_read_response = self._build_send_fn(return_packets, num_invalid_try) + stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=sync_read_response, + ) + return stub_name + + def build_sequential_sync_read_stub( + self, address: int, length: int, ids_values: dict[int, list[int]] | None = None + ) -> str: + sequence_length = len(next(iter(ids_values.values()))) + assert all(len(positions) == sequence_length for positions in ids_values.values()) + sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) + sequential_packets = [] + for count in range(sequence_length): + return_packets = b"".join( + MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items() + ) + sequential_packets.append(return_packets) + + sync_read_response = self._build_sequential_send_fn(sequential_packets) + stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=sync_read_response, + ) + return stub_name + + def build_sync_write_stub( + self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0 + ) -> str: + sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) + stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(b"", num_invalid_try), + ) + return stub_name + + @staticmethod + def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: + def send_fn(_call_count: int) -> bytes: + if num_invalid_try >= _call_count: + return b"" + return packet + + return send_fn + + @staticmethod + def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]: + def send_fn(_call_count: int) -> bytes: + return packets[_call_count - 1] + + return send_fn diff --git a/tests/mocks/mock_feetech.py b/tests/mocks/mock_feetech.py new file mode 100644 index 000000000..5279b1dc8 --- /dev/null +++ b/tests/mocks/mock_feetech.py @@ -0,0 +1,428 @@ +import abc +from typing import Callable + +import scservo_sdk as scs +import serial +from mock_serial import MockSerial + +from lerobot.common.motors.feetech.feetech import _split_into_byte_chunks, patch_setPacketTimeout + +from .mock_serial_patch import WaitableStub + + +class MockFeetechPacket(abc.ABC): + @classmethod + def build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> bytes: + packet = cls._build(scs_id, params, length, *args, **kwargs) + packet = cls._add_checksum(packet) + return bytes(packet) + + @abc.abstractclassmethod + def _build(cls, scs_id: int, params: list[int], length: int, *args, **kwargs) -> list[int]: + pass + + @staticmethod + def _add_checksum(packet: list[int]) -> list[int]: + checksum = 0 + for id_ in range(2, len(packet) - 1): # except header & checksum + checksum += packet[id_] + + packet[-1] = ~checksum & 0xFF + + return packet + + +class MockInstructionPacket(MockFeetechPacket): + """ + Helper class to build valid Feetech Instruction Packets. + + Instruction Packet structure + (from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf) + + | Header | Packet ID | Length | Instruction | Params | Checksum | + | --------- | --------- | ------ | ----------- | ----------------- | -------- | + | 0xFF 0xFF | ID | Len | Instr | Param 1 … Param N | Sum | + + """ + + @classmethod + def _build(cls, scs_id: int, params: list[int], length: int, instruction: int) -> list[int]: + return [ + 0xFF, 0xFF, # header + scs_id, # servo id + length, # length + instruction, # instruction type + *params, # data bytes + 0x00, # placeholder for checksum + ] # fmt: skip + + @classmethod + def ping( + cls, + scs_id: int, + ) -> bytes: + """ + Builds a "Ping" broadcast instruction. + + No parameters required. + """ + return cls.build(scs_id=scs_id, params=[], length=2, instruction=scs.INST_PING) + + @classmethod + def read( + cls, + scs_id: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Read" instruction. + + The parameters for Read are: + param[0] = start_address + param[1] = data_length + + And 'length' = 4, where: + +1 is for instruction byte, + +1 is for the address byte, + +1 is for the length bytes, + +1 is for the checksum at the end. + """ + params = [start_address, data_length] + length = 4 + return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_READ) + + @classmethod + def write( + cls, + scs_id: int, + value: int, + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Write" instruction. + + The parameters for Write are: + param[0] = start_address L + param[1] = start_address H + param[2] = 1st Byte + param[3] = 2nd Byte + ... + param[1+X] = X-th Byte + + And 'length' = data_length + 3, where: + +1 is for instruction byte, + +1 is for the length bytes, + +1 is for the checksum at the end. + """ + data = _split_into_byte_chunks(value, data_length) + params = [start_address, *data] + length = data_length + 3 + return cls.build(scs_id=scs_id, params=params, length=length, instruction=scs.INST_WRITE) + + @classmethod + def sync_read( + cls, + scs_ids: list[int], + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Sync_Read" broadcast instruction. + + The parameters for Sync Read are: + param[0] = start_address + param[1] = data_length + param[2+] = motor IDs to read from + + And 'length' = (number_of_params + 4), where: + +1 is for instruction byte, + +1 is for the address byte, + +1 is for the length bytes, + +1 is for the checksum at the end. + """ + params = [start_address, data_length, *scs_ids] + length = len(scs_ids) + 4 + return cls.build( + scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_READ + ) + + @classmethod + def sync_write( + cls, + ids_values: dict[int, int], + start_address: int, + data_length: int, + ) -> bytes: + """ + Builds a "Sync_Write" broadcast instruction. + + The parameters for Sync_Write are: + param[0] = start_address + param[1] = data_length + param[2] = [1st motor] ID + param[2+1] = [1st motor] 1st Byte + param[2+2] = [1st motor] 2nd Byte + ... + param[5+X] = [1st motor] X-th Byte + param[6] = [2nd motor] ID + param[6+1] = [2nd motor] 1st Byte + param[6+2] = [2nd motor] 2nd Byte + ... + param[6+X] = [2nd motor] X-th Byte + + And 'length' = ((number_of_params * 1 + data_length) + 4), where: + +1 is for instruction byte, + +1 is for the address byte, + +1 is for the length bytes, + +1 is for the checksum at the end. + """ + data = [] + for id_, value in ids_values.items(): + split_value = _split_into_byte_chunks(value, data_length) + data += [id_, *split_value] + params = [start_address, data_length, *data] + length = len(ids_values) * (1 + data_length) + 4 + return cls.build( + scs_id=scs.BROADCAST_ID, params=params, length=length, instruction=scs.INST_SYNC_WRITE + ) + + +class MockStatusPacket(MockFeetechPacket): + """ + Helper class to build valid Feetech Status Packets. + + Status Packet structure + (from https://files.waveshare.com/upload/2/27/Communication_Protocol_User_Manual-EN%28191218-0923%29.pdf) + + | Header | Packet ID | Length | Error | Params | Checksum | + | --------- | --------- | ------ | ----- | ----------------- | -------- | + | 0xFF 0xFF | ID | Len | Err | Param 1 … Param N | Sum | + + """ + + @classmethod + def _build(cls, scs_id: int, params: list[int], length: int, error: int = 0) -> list[int]: + return [ + 0xFF, 0xFF, # header + scs_id, # servo id + length, # length + error, # status + *params, # data bytes + 0x00, # placeholder for checksum + ] # fmt: skip + + @classmethod + def ping(cls, scs_id: int, error: int = 0) -> bytes: + """Builds a 'Ping' status packet. + + Args: + scs_id (int): ID of the servo responding. + error (int, optional): Error to be returned. Defaults to 0 (success). + + Returns: + bytes: The raw 'Ping' status packet ready to be sent through serial. + """ + return cls.build(scs_id, params=[], length=2, error=error) + + @classmethod + def read(cls, scs_id: int, value: int, param_length: int, error: int = 0) -> bytes: + """Builds a 'Read' status packet. + + Args: + scs_id (int): ID of the servo responding. + value (int): Desired value to be returned in the packet. + param_length (int): The address length as reported in the control table. + + Returns: + bytes: The raw 'Sync Read' status packet ready to be sent through serial. + """ + params = _split_into_byte_chunks(value, param_length) + length = param_length + 2 + return cls.build(scs_id, params=params, length=length, error=error) + + +class MockPortHandler(scs.PortHandler): + """ + This class overwrite the 'setupPort' method of the Feetech PortHandler because it can specify + baudrates that are not supported with a serial port on MacOS. + """ + + def setupPort(self, cflag_baud): # noqa: N802 + if self.is_open: + self.closePort() + + self.ser = serial.Serial( + port=self.port_name, + # baudrate=self.baudrate, <- This will fail on MacOS + # parity = serial.PARITY_ODD, + # stopbits = serial.STOPBITS_TWO, + bytesize=serial.EIGHTBITS, + timeout=0, + ) + self.is_open = True + self.ser.reset_input_buffer() + self.tx_time_per_byte = (1000.0 / self.baudrate) * 10.0 + + return True + + def setPacketTimeout(self, packet_length): # noqa: N802 + return patch_setPacketTimeout(self, packet_length) + + +class MockMotors(MockSerial): + """ + This class will simulate physical motors by responding with valid status packets upon receiving some + instruction packets. It is meant to test MotorsBus classes. + """ + + def __init__(self): + super().__init__() + + @property + def stubs(self) -> dict[str, WaitableStub]: + return super().stubs + + def stub(self, *, name=None, **kwargs): + new_stub = WaitableStub(**kwargs) + self._MockSerial__stubs[name or new_stub.receive_bytes] = new_stub + return new_stub + + def build_broadcast_ping_stub(self, ids: list[int] | None = None, num_invalid_try: int = 0) -> str: + ping_request = MockInstructionPacket.ping(scs.BROADCAST_ID) + return_packets = b"".join(MockStatusPacket.ping(id_) for id_ in ids) + ping_response = self._build_send_fn(return_packets, num_invalid_try) + stub_name = "Ping_" + "_".join([str(id_) for id_ in ids]) + self.stub( + name=stub_name, + receive_bytes=ping_request, + send_fn=ping_response, + ) + return stub_name + + def build_ping_stub(self, scs_id: int, num_invalid_try: int = 0, error: int = 0) -> str: + ping_request = MockInstructionPacket.ping(scs_id) + return_packet = MockStatusPacket.ping(scs_id, error) + ping_response = self._build_send_fn(return_packet, num_invalid_try) + stub_name = f"Ping_{scs_id}_{error}" + self.stub( + name=stub_name, + receive_bytes=ping_request, + send_fn=ping_response, + ) + return stub_name + + def build_read_stub( + self, + address: int, + length: int, + scs_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, + ) -> str: + read_request = MockInstructionPacket.read(scs_id, address, length) + return_packet = MockStatusPacket.read(scs_id, value, length, error) if reply else b"" + read_response = self._build_send_fn(return_packet, num_invalid_try) + stub_name = f"Read_{address}_{length}_{scs_id}_{value}_{error}" + self.stub( + name=stub_name, + receive_bytes=read_request, + send_fn=read_response, + ) + return stub_name + + def build_write_stub( + self, + address: int, + length: int, + scs_id: int, + value: int, + reply: bool = True, + error: int = 0, + num_invalid_try: int = 0, + ) -> str: + sync_read_request = MockInstructionPacket.write(scs_id, value, address, length) + return_packet = MockStatusPacket.build(scs_id, params=[], length=2, error=error) if reply else b"" + stub_name = f"Write_{address}_{length}_{scs_id}" + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(return_packet, num_invalid_try), + ) + return stub_name + + def build_sync_read_stub( + self, + address: int, + length: int, + ids_values: dict[int, int], + reply: bool = True, + num_invalid_try: int = 0, + ) -> str: + sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) + return_packets = ( + b"".join(MockStatusPacket.read(id_, pos, length) for id_, pos in ids_values.items()) + if reply + else b"" + ) + sync_read_response = self._build_send_fn(return_packets, num_invalid_try) + stub_name = f"Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=sync_read_response, + ) + return stub_name + + def build_sequential_sync_read_stub( + self, address: int, length: int, ids_values: dict[int, list[int]] | None = None + ) -> str: + sequence_length = len(next(iter(ids_values.values()))) + assert all(len(positions) == sequence_length for positions in ids_values.values()) + sync_read_request = MockInstructionPacket.sync_read(list(ids_values), address, length) + sequential_packets = [] + for count in range(sequence_length): + return_packets = b"".join( + MockStatusPacket.read(id_, positions[count], length) for id_, positions in ids_values.items() + ) + sequential_packets.append(return_packets) + + sync_read_response = self._build_sequential_send_fn(sequential_packets) + stub_name = f"Seq_Sync_Read_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=sync_read_response, + ) + return stub_name + + def build_sync_write_stub( + self, address: int, length: int, ids_values: dict[int, int], num_invalid_try: int = 0 + ) -> str: + sync_read_request = MockInstructionPacket.sync_write(ids_values, address, length) + stub_name = f"Sync_Write_{address}_{length}_" + "_".join([str(id_) for id_ in ids_values]) + self.stub( + name=stub_name, + receive_bytes=sync_read_request, + send_fn=self._build_send_fn(b"", num_invalid_try), + ) + return stub_name + + @staticmethod + def _build_send_fn(packet: bytes, num_invalid_try: int = 0) -> Callable[[int], bytes]: + def send_fn(_call_count: int) -> bytes: + if num_invalid_try >= _call_count: + return b"" + return packet + + return send_fn + + @staticmethod + def _build_sequential_send_fn(packets: list[bytes]) -> Callable[[int], bytes]: + def send_fn(_call_count: int) -> bytes: + return packets[_call_count - 1] + + return send_fn diff --git a/tests/mocks/mock_motors_bus.py b/tests/mocks/mock_motors_bus.py new file mode 100644 index 000000000..e322eae8a --- /dev/null +++ b/tests/mocks/mock_motors_bus.py @@ -0,0 +1,138 @@ +# ruff: noqa: N802 + +from lerobot.common.motors.motors_bus import ( + Motor, + MotorsBus, +) + +DUMMY_CTRL_TABLE_1 = { + "Firmware_Version": (0, 1), + "Model_Number": (1, 2), + "Present_Position": (3, 4), + "Goal_Position": (11, 2), +} + +DUMMY_CTRL_TABLE_2 = { + "Model_Number": (0, 2), + "Firmware_Version": (2, 1), + "Present_Position": (3, 4), + "Present_Velocity": (7, 4), + "Goal_Position": (11, 4), + "Goal_Velocity": (15, 4), + "Lock": (19, 1), +} + +DUMMY_MODEL_CTRL_TABLE = { + "model_1": DUMMY_CTRL_TABLE_1, + "model_2": DUMMY_CTRL_TABLE_2, + "model_3": DUMMY_CTRL_TABLE_2, +} + +DUMMY_BAUDRATE_TABLE = { + 0: 1_000_000, + 1: 500_000, + 2: 250_000, +} + +DUMMY_MODEL_BAUDRATE_TABLE = { + "model_1": DUMMY_BAUDRATE_TABLE, + "model_2": DUMMY_BAUDRATE_TABLE, + "model_3": DUMMY_BAUDRATE_TABLE, +} + +DUMMY_ENCODING_TABLE = { + "Present_Position": 8, + "Goal_Position": 10, +} + +DUMMY_MODEL_ENCODING_TABLE = { + "model_1": DUMMY_ENCODING_TABLE, + "model_2": DUMMY_ENCODING_TABLE, + "model_3": DUMMY_ENCODING_TABLE, +} + +DUMMY_MODEL_NUMBER_TABLE = { + "model_1": 1234, + "model_2": 5678, + "model_3": 5799, +} + +DUMMY_MODEL_RESOLUTION_TABLE = { + "model_1": 4096, + "model_2": 1024, + "model_3": 4096, +} + + +class MockPortHandler: + def __init__(self, port_name): + self.is_open: bool = False + self.baudrate: int + self.packet_start_time: float + self.packet_timeout: float + self.tx_time_per_byte: float + self.is_using: bool = False + self.port_name: str = port_name + self.ser = None + + def openPort(self): + self.is_open = True + return self.is_open + + def closePort(self): + self.is_open = False + + def clearPort(self): ... + def setPortName(self, port_name): + self.port_name = port_name + + def getPortName(self): + return self.port_name + + def setBaudRate(self, baudrate): + self.baudrate: baudrate + + def getBaudRate(self): + return self.baudrate + + def getBytesAvailable(self): ... + def readPort(self, length): ... + def writePort(self, packet): ... + def setPacketTimeout(self, packet_length): ... + def setPacketTimeoutMillis(self, msec): ... + def isPacketTimeout(self): ... + def getCurrentTime(self): ... + def getTimeSinceStart(self): ... + def setupPort(self, cflag_baud): ... + def getCFlagBaud(self, baudrate): ... + + +class MockMotorsBus(MotorsBus): + available_baudrates = [500_000, 1_000_000] + default_timeout = 1000 + model_baudrate_table = DUMMY_MODEL_BAUDRATE_TABLE + model_ctrl_table = DUMMY_MODEL_CTRL_TABLE + model_encoding_table = DUMMY_MODEL_ENCODING_TABLE + model_number_table = DUMMY_MODEL_NUMBER_TABLE + model_resolution_table = DUMMY_MODEL_RESOLUTION_TABLE + normalized_data = ["Present_Position", "Goal_Position"] + + def __init__(self, port: str, motors: dict[str, Motor]): + super().__init__(port, motors) + self.port_handler = MockPortHandler(port) + + def _assert_protocol_is_compatible(self, instruction_name): ... + def _handshake(self): ... + def _find_single_motor(self, motor, initial_baudrate): ... + def configure_motors(self): ... + def is_calibrated(self): ... + def read_calibration(self): ... + def write_calibration(self, calibration_dict): ... + def disable_torque(self, motors, num_retry): ... + def _disable_torque(self, motor, model, num_retry): ... + def enable_torque(self, motors, num_retry): ... + def _get_half_turn_homings(self, positions): ... + def _encode_sign(self, data_name, ids_values): ... + def _decode_sign(self, data_name, ids_values): ... + def _split_into_byte_chunks(self, value, length): ... + def broadcast_ping(self, num_retry, raise_on_error): ... diff --git a/tests/mocks/mock_robot.py b/tests/mocks/mock_robot.py new file mode 100644 index 000000000..40d8fbde6 --- /dev/null +++ b/tests/mocks/mock_robot.py @@ -0,0 +1,112 @@ +import random +from dataclasses import dataclass, field +from functools import cached_property +from typing import Any + +from lerobot.common.cameras import CameraConfig, make_cameras_from_configs +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.robots import Robot, RobotConfig + + +@RobotConfig.register_subclass("mock_robot") +@dataclass +class MockRobotConfig(RobotConfig): + n_motors: int = 3 + cameras: dict[str, CameraConfig] = field(default_factory=dict) + random_values: bool = True + static_values: list[float] | None = None + calibrated: bool = True + + def __post_init__(self): + if self.n_motors < 1: + raise ValueError(self.n_motors) + + if self.random_values and self.static_values is not None: + raise ValueError("Choose either random values or static values") + + if self.static_values is not None and len(self.static_values) != self.n_motors: + raise ValueError("Specify the same number of static values as motors") + + if len(self.cameras) > 0: + raise NotImplementedError # TODO with the cameras refactor + + +class MockRobot(Robot): + """Mock Robot to be used for testing.""" + + config_class = MockRobotConfig + name = "mock_robot" + + def __init__(self, config: MockRobotConfig): + super().__init__(config) + self.config = config + self._is_connected = False + self._is_calibrated = config.calibrated + self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)] + self.cameras = make_cameras_from_configs(config.cameras) + + @property + def _motors_ft(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.motors} + + @property + def _cameras_ft(self) -> dict[str, tuple]: + return { + cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras + } + + @cached_property + def observation_features(self) -> dict[str, type | tuple]: + return {**self._motors_ft, **self._cameras_ft} + + @cached_property + def action_features(self) -> dict[str, type]: + return self._motors_ft + + @property + def is_connected(self) -> bool: + return self._is_connected + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self._is_connected = True + if calibrate: + self.calibrate() + + @property + def is_calibrated(self) -> bool: + return self._is_calibrated + + def calibrate(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._is_calibrated = True + + def configure(self) -> None: + pass + + def get_observation(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.config.random_values: + return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} + else: + return { + f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) + } + + def send_action(self, action: dict[str, Any]) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + return action + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._is_connected = False diff --git a/tests/mocks/mock_serial_patch.py b/tests/mocks/mock_serial_patch.py new file mode 100644 index 000000000..e39923188 --- /dev/null +++ b/tests/mocks/mock_serial_patch.py @@ -0,0 +1,35 @@ +import threading +import time + +from mock_serial.mock_serial import Stub + + +class WaitableStub(Stub): + """ + In some situations, a test might be checking if a stub has been called before `MockSerial` thread had time + to read, match, and call the stub. In these situations, the test can fail randomly. + + Use `wait_called()` or `wait_calls()` to block until the stub is called, avoiding race conditions. + + Proposed fix: + https://github.com/benthorner/mock_serial/pull/3 + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._event = threading.Event() + + def call(self): + self._event.set() + return super().call() + + def wait_called(self, timeout: float = 1.0): + return self._event.wait(timeout) + + def wait_calls(self, min_calls: int = 1, timeout: float = 1.0): + start = time.perf_counter() + while time.perf_counter() - start < timeout: + if self.calls >= min_calls: + return self.calls + time.sleep(0.005) + raise TimeoutError(f"Stub not called {min_calls} times within {timeout} seconds.") diff --git a/tests/mocks/mock_teleop.py b/tests/mocks/mock_teleop.py new file mode 100644 index 000000000..a7f5cad35 --- /dev/null +++ b/tests/mocks/mock_teleop.py @@ -0,0 +1,94 @@ +import random +from dataclasses import dataclass +from functools import cached_property +from typing import Any + +from lerobot.common.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError +from lerobot.common.teleoperators import Teleoperator, TeleoperatorConfig + + +@TeleoperatorConfig.register_subclass("mock_teleop") +@dataclass +class MockTeleopConfig(TeleoperatorConfig): + n_motors: int = 3 + random_values: bool = True + static_values: list[float] | None = None + calibrated: bool = True + + def __post_init__(self): + if self.n_motors < 1: + raise ValueError(self.n_motors) + + if self.random_values and self.static_values is not None: + raise ValueError("Choose either random values or static values") + + if self.static_values is not None and len(self.static_values) != self.n_motors: + raise ValueError("Specify the same number of static values as motors") + + +class MockTeleop(Teleoperator): + """Mock Teleoperator to be used for testing.""" + + config_class = MockTeleopConfig + name = "mock_teleop" + + def __init__(self, config: MockTeleopConfig): + super().__init__(config) + self.config = config + self._is_connected = False + self._is_calibrated = config.calibrated + self.motors = [f"motor_{i + 1}" for i in range(config.n_motors)] + + @cached_property + def action_features(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.motors} + + @cached_property + def feedback_features(self) -> dict[str, type]: + return {f"{motor}.pos": float for motor in self.motors} + + @property + def is_connected(self) -> bool: + return self._is_connected + + def connect(self, calibrate: bool = True) -> None: + if self.is_connected: + raise DeviceAlreadyConnectedError(f"{self} already connected") + + self._is_connected = True + if calibrate: + self.calibrate() + + @property + def is_calibrated(self) -> bool: + return self._is_calibrated + + def calibrate(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._is_calibrated = True + + def configure(self) -> None: + pass + + def get_action(self) -> dict[str, Any]: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + if self.config.random_values: + return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors} + else: + return { + f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True) + } + + def send_feedback(self, feedback: dict[str, Any]) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + def disconnect(self) -> None: + if not self.is_connected: + raise DeviceNotConnectedError(f"{self} is not connected.") + + self._is_connected = False diff --git a/tests/motors/mock_dynamixel_sdk.py b/tests/motors/mock_dynamixel_sdk.py deleted file mode 100644 index ee399f96d..000000000 --- a/tests/motors/mock_dynamixel_sdk.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration -and testing code logic that requires hardware and devices (e.g. robot arms, cameras) - -Warning: These mocked versions are minimalist. They do not exactly mock every behaviors -from the original classes and functions (e.g. return types might be None instead of boolean). -""" - -# from dynamixel_sdk import COMM_SUCCESS - -DEFAULT_BAUDRATE = 9_600 -COMM_SUCCESS = 0 # tx or rx packet communication success - - -def convert_to_bytes(value, bytes): - # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform - # `convert_bytes_to_value` - del bytes # unused - return value - - -def get_default_motor_values(motor_index): - return { - # Key (int) are from X_SERIES_CONTROL_TABLE - 7: motor_index, # ID - 8: DEFAULT_BAUDRATE, # Baud_rate - 10: 0, # Drive_Mode - 64: 0, # Torque_Enable - # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 - # For other joints, 2560 will be autocorrected to be in calibration range - 132: 2560, # Present_Position - } - - -class PortHandler: - def __init__(self, port): - self.port = port - # factory default baudrate - self.baudrate = DEFAULT_BAUDRATE - - def openPort(self): # noqa: N802 - return True - - def closePort(self): # noqa: N802 - pass - - def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802 - del timeout_ms # unused - - def getBaudRate(self): # noqa: N802 - return self.baudrate - - def setBaudRate(self, baudrate): # noqa: N802 - self.baudrate = baudrate - - -class PacketHandler: - def __init__(self, protocol_version): - del protocol_version # unused - # Use packet_handler.data to communicate across Read and Write - self.data = {} - - -class GroupSyncRead: - def __init__(self, port_handler, packet_handler, address, bytes): - self.packet_handler = packet_handler - - def addParam(self, motor_index): # noqa: N802 - # Initialize motor default values - if motor_index not in self.packet_handler.data: - self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) - - def txRxPacket(self): # noqa: N802 - return COMM_SUCCESS - - def getData(self, index, address, bytes): # noqa: N802 - return self.packet_handler.data[index][address] - - -class GroupSyncWrite: - def __init__(self, port_handler, packet_handler, address, bytes): - self.packet_handler = packet_handler - self.address = address - - def addParam(self, index, data): # noqa: N802 - # Initialize motor default values - if index not in self.packet_handler.data: - self.packet_handler.data[index] = get_default_motor_values(index) - self.changeParam(index, data) - - def txPacket(self): # noqa: N802 - return COMM_SUCCESS - - def changeParam(self, index, data): # noqa: N802 - self.packet_handler.data[index][self.address] = data diff --git a/tests/motors/mock_scservo_sdk.py b/tests/motors/mock_scservo_sdk.py deleted file mode 100644 index 37f6d0d56..000000000 --- a/tests/motors/mock_scservo_sdk.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Mocked classes and functions from dynamixel_sdk to allow for continuous integration -and testing code logic that requires hardware and devices (e.g. robot arms, cameras) - -Warning: These mocked versions are minimalist. They do not exactly mock every behaviors -from the original classes and functions (e.g. return types might be None instead of boolean). -""" - -# from dynamixel_sdk import COMM_SUCCESS - -DEFAULT_BAUDRATE = 1_000_000 -COMM_SUCCESS = 0 # tx or rx packet communication success - - -def convert_to_bytes(value, bytes): - # TODO(rcadene): remove need to mock `convert_to_bytes` by implemented the inverse transform - # `convert_bytes_to_value` - del bytes # unused - return value - - -def get_default_motor_values(motor_index): - return { - # Key (int) are from SCS_SERIES_CONTROL_TABLE - 5: motor_index, # ID - 6: DEFAULT_BAUDRATE, # Baud_rate - 10: 0, # Drive_Mode - 21: 32, # P_Coefficient - 22: 32, # D_Coefficient - 23: 0, # I_Coefficient - 40: 0, # Torque_Enable - 41: 254, # Acceleration - 31: -2047, # Offset - 33: 0, # Mode - 55: 1, # Lock - # Set 2560 since calibration values for Aloha gripper is between start_pos=2499 and end_pos=3144 - # For other joints, 2560 will be autocorrected to be in calibration range - 56: 2560, # Present_Position - 58: 0, # Present_Speed - 69: 0, # Present_Current - 85: 150, # Maximum_Acceleration - } - - -class PortHandler: - def __init__(self, port): - self.port = port - # factory default baudrate - self.baudrate = DEFAULT_BAUDRATE - self.ser = SerialMock() - - def openPort(self): # noqa: N802 - return True - - def closePort(self): # noqa: N802 - pass - - def setPacketTimeoutMillis(self, timeout_ms): # noqa: N802 - del timeout_ms # unused - - def getBaudRate(self): # noqa: N802 - return self.baudrate - - def setBaudRate(self, baudrate): # noqa: N802 - self.baudrate = baudrate - - -class PacketHandler: - def __init__(self, protocol_version): - del protocol_version # unused - # Use packet_handler.data to communicate across Read and Write - self.data = {} - - -class GroupSyncRead: - def __init__(self, port_handler, packet_handler, address, bytes): - self.packet_handler = packet_handler - - def addParam(self, motor_index): # noqa: N802 - # Initialize motor default values - if motor_index not in self.packet_handler.data: - self.packet_handler.data[motor_index] = get_default_motor_values(motor_index) - - def txRxPacket(self): # noqa: N802 - return COMM_SUCCESS - - def getData(self, index, address, bytes): # noqa: N802 - return self.packet_handler.data[index][address] - - -class GroupSyncWrite: - def __init__(self, port_handler, packet_handler, address, bytes): - self.packet_handler = packet_handler - self.address = address - - def addParam(self, index, data): # noqa: N802 - if index not in self.packet_handler.data: - self.packet_handler.data[index] = get_default_motor_values(index) - self.changeParam(index, data) - - def txPacket(self): # noqa: N802 - return COMM_SUCCESS - - def changeParam(self, index, data): # noqa: N802 - self.packet_handler.data[index][self.address] = data - - -class SerialMock: - def reset_output_buffer(self): - pass - - def reset_input_buffer(self): - pass diff --git a/tests/motors/test_dynamixel.py b/tests/motors/test_dynamixel.py new file mode 100644 index 000000000..dcce8f691 --- /dev/null +++ b/tests/motors/test_dynamixel.py @@ -0,0 +1,400 @@ +import re +import sys +from typing import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.dynamixel import MODEL_NUMBER_TABLE, DynamixelMotorsBus +from lerobot.common.motors.dynamixel.tables import X_SERIES_CONTROL_TABLE +from lerobot.common.utils.encoding_utils import encode_twos_complement + +try: + import dynamixel_sdk as dxl + + from tests.mocks.mock_dynamixel import MockMotors, MockPortHandler +except (ImportError, ModuleNotFoundError): + pytest.skip("dynamixel_sdk not available", allow_module_level=True) + + +@pytest.fixture(autouse=True) +def patch_port_handler(): + if sys.platform == "darwin": + with patch.object(dxl, "PortHandler", MockPortHandler): + yield + else: + yield + + +@pytest.fixture +def mock_motors() -> Generator[MockMotors, None, None]: + motors = MockMotors() + motors.open() + yield motors + motors.close() + + +@pytest.fixture +def dummy_motors() -> dict[str, Motor]: + return { + "dummy_1": Motor(1, "xl430-w250", MotorNormMode.RANGE_M100_100), + "dummy_2": Motor(2, "xm540-w270", MotorNormMode.RANGE_M100_100), + "dummy_3": Motor(3, "xl330-m077", MotorNormMode.RANGE_M100_100), + } + + +@pytest.fixture +def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]: + drive_modes = [0, 1, 0] + homings = [-709, -2006, 1624] + mins = [43, 27, 145] + maxes = [1335, 3608, 3999] + calibration = {} + for motor, m in dummy_motors.items(): + calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=drive_modes[m.id - 1], + homing_offset=homings[m.id - 1], + range_min=mins[m.id - 1], + range_max=maxes[m.id - 1], + ) + return calibration + + +@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") +def test_autouse_patch(): + """Ensures that the autouse fixture correctly patches dxl.PortHandler with MockPortHandler.""" + assert dxl.PortHandler is MockPortHandler + + +@pytest.mark.parametrize( + "value, length, expected", + [ + (0x12, 1, [0x12]), + (0x1234, 2, [0x34, 0x12]), + (0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), + ], + ids=[ + "1 byte", + "2 bytes", + "4 bytes", + ], +) # fmt: skip +def test__split_into_byte_chunks(value, length, expected): + bus = DynamixelMotorsBus("", {}) + assert bus._split_into_byte_chunks(value, length) == expected + + +def test_abc_implementation(dummy_motors): + """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" + DynamixelMotorsBus(port="/dev/dummy-port", motors=dummy_motors) + + +@pytest.mark.parametrize("id_", [1, 2, 3]) +def test_ping(id_, mock_motors, dummy_motors): + expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] + stub = mock_motors.build_ping_stub(id_, expected_model_nb) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + ping_model_nb = bus.ping(id_) + + assert ping_model_nb == expected_model_nb + assert mock_motors.stubs[stub].called + + +def test_broadcast_ping(mock_motors, dummy_motors): + models = {m.id: m.model for m in dummy_motors.values()} + expected_model_nbs = {id_: MODEL_NUMBER_TABLE[model] for id_, model in models.items()} + stub = mock_motors.build_broadcast_ping_stub(expected_model_nbs) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + ping_model_nbs = bus.broadcast_ping() + + assert ping_model_nbs == expected_model_nbs + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, id_, value", + [ + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), + ], +) +def test__read(addr, length, id_, value, mock_motors, dummy_motors): + stub = mock_motors.build_read_stub(addr, length, id_, value) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + read_value, _, _ = bus._read(addr, length, id_) + + assert mock_motors.stubs[stub].called + assert read_value == value + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) + stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises( + RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") + ): + bus._read(addr, length, id_, raise_on_error=raise_on_error) + else: + _, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_error == error + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + bus._read(addr, length, id_, raise_on_error=raise_on_error) + else: + _, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_comm == dxl.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, id_, value", + [ + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), + ], +) +def test__write(addr, length, id_, value, mock_motors, dummy_motors): + stub = mock_motors.build_write_stub(addr, length, id_, value) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + comm, error = bus._write(addr, length, id_, value) + + assert mock_motors.stubs[stub].called + assert comm == dxl.COMM_SUCCESS + assert error == 0 + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, dxl.ERRNUM_DATA_LIMIT) + stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises( + RuntimeError, match=re.escape("[RxPacketError] The data value exceeds the limit value!") + ): + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + _, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_error == error + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_comm == dxl.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): + stub = mock_motors.build_sync_read_stub(addr, length, ids_values) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + read_values, _ = bus._sync_read(addr, length, list(ids_values)) + + assert mock_motors.stubs[stub].called + assert read_values == ids_values + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, ids_values = (10, 4, {1: 1337}) + stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + else: + _, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + assert read_comm == dxl.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): + stub = mock_motors.build_sync_write_stub(addr, length, ids_values) + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + comm = bus._sync_write(addr, length, ids_values) + + assert mock_motors.stubs[stub].wait_called() + assert comm == dxl.COMM_SUCCESS + + +def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): + drive_modes = {m.id: m.drive_mode for m in dummy_calibration.values()} + encoded_homings = {m.id: encode_twos_complement(m.homing_offset, 4) for m in dummy_calibration.values()} + mins = {m.id: m.range_min for m in dummy_calibration.values()} + maxes = {m.id: m.range_max for m in dummy_calibration.values()} + drive_modes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Drive_Mode"], drive_modes) + offsets_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], encoded_homings) + mins_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], mins) + maxes_stub = mock_motors.build_sync_read_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], maxes) + bus = DynamixelMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + calibration=dummy_calibration, + ) + bus.connect(handshake=False) + + is_calibrated = bus.is_calibrated + + assert is_calibrated + assert mock_motors.stubs[drive_modes_stub].called + assert mock_motors.stubs[offsets_stub].called + assert mock_motors.stubs[mins_stub].called + assert mock_motors.stubs[maxes_stub].called + + +def test_reset_calibration(mock_motors, dummy_motors): + write_homing_stubs = [] + write_mins_stubs = [] + write_maxes_stubs = [] + for motor in dummy_motors.values(): + write_homing_stubs.append( + mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0) + ) + write_mins_stubs.append( + mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0) + ) + write_maxes_stubs.append( + mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) + ) + + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + bus.reset_calibration() + + assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) + assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs) + assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs) + + +def test_set_half_turn_homings(mock_motors, dummy_motors): + """ + For this test, we assume that the homing offsets are already 0 such that + Present_Position == Actual_Position + """ + current_positions = { + 1: 1337, + 2: 42, + 3: 3672, + } + expected_homings = { + 1: 710, # 2047 - 1337 + 2: 2005, # 2047 - 42 + 3: -1625, # 2047 - 3672 + } + read_pos_stub = mock_motors.build_sync_read_stub( + *X_SERIES_CONTROL_TABLE["Present_Position"], current_positions + ) + write_homing_stubs = [] + for id_, homing in expected_homings.items(): + encoded_homing = encode_twos_complement(homing, 4) + stub = mock_motors.build_write_stub(*X_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing) + write_homing_stubs.append(stub) + + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + bus.reset_calibration = MagicMock() + + bus.set_half_turn_homings() + + bus.reset_calibration.assert_called_once() + assert mock_motors.stubs[read_pos_stub].called + assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) + + +def test_record_ranges_of_motion(mock_motors, dummy_motors): + positions = { + 1: [351, 42, 1337], + 2: [28, 3600, 2444], + 3: [4002, 2999, 146], + } + expected_mins = { + "dummy_1": 42, + "dummy_2": 28, + "dummy_3": 146, + } + expected_maxes = { + "dummy_1": 1337, + "dummy_2": 3600, + "dummy_3": 4002, + } + read_pos_stub = mock_motors.build_sequential_sync_read_stub( + *X_SERIES_CONTROL_TABLE["Present_Position"], positions + ) + with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): + bus = DynamixelMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + mins, maxes = bus.record_ranges_of_motion(display_values=False) + + assert mock_motors.stubs[read_pos_stub].calls == 3 + assert mins == expected_mins + assert maxes == expected_maxes diff --git a/tests/motors/test_feetech.py b/tests/motors/test_feetech.py new file mode 100644 index 000000000..2e7b2ff77 --- /dev/null +++ b/tests/motors/test_feetech.py @@ -0,0 +1,443 @@ +import re +import sys +from typing import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from lerobot.common.motors import Motor, MotorCalibration, MotorNormMode +from lerobot.common.motors.feetech import MODEL_NUMBER, MODEL_NUMBER_TABLE, FeetechMotorsBus +from lerobot.common.motors.feetech.tables import STS_SMS_SERIES_CONTROL_TABLE +from lerobot.common.utils.encoding_utils import encode_sign_magnitude + +try: + import scservo_sdk as scs + + from tests.mocks.mock_feetech import MockMotors, MockPortHandler +except (ImportError, ModuleNotFoundError): + pytest.skip("scservo_sdk not available", allow_module_level=True) + + +@pytest.fixture(autouse=True) +def patch_port_handler(): + if sys.platform == "darwin": + with patch.object(scs, "PortHandler", MockPortHandler): + yield + else: + yield + + +@pytest.fixture +def mock_motors() -> Generator[MockMotors, None, None]: + motors = MockMotors() + motors.open() + yield motors + motors.close() + + +@pytest.fixture +def dummy_motors() -> dict[str, Motor]: + return { + "dummy_1": Motor(1, "sts3215", MotorNormMode.RANGE_M100_100), + "dummy_2": Motor(2, "sts3215", MotorNormMode.RANGE_M100_100), + "dummy_3": Motor(3, "sts3215", MotorNormMode.RANGE_M100_100), + } + + +@pytest.fixture +def dummy_calibration(dummy_motors) -> dict[str, MotorCalibration]: + homings = [-709, -2006, 1624] + mins = [43, 27, 145] + maxes = [1335, 3608, 3999] + calibration = {} + for motor, m in dummy_motors.items(): + calibration[motor] = MotorCalibration( + id=m.id, + drive_mode=0, + homing_offset=homings[m.id - 1], + range_min=mins[m.id - 1], + range_max=maxes[m.id - 1], + ) + return calibration + + +@pytest.mark.skipif(sys.platform != "darwin", reason=f"No patching needed on {sys.platform=}") +def test_autouse_patch(): + """Ensures that the autouse fixture correctly patches scs.PortHandler with MockPortHandler.""" + assert scs.PortHandler is MockPortHandler + + +@pytest.mark.parametrize( + "protocol, value, length, expected", + [ + (0, 0x12, 1, [0x12]), + (1, 0x12, 1, [0x12]), + (0, 0x1234, 2, [0x34, 0x12]), + (1, 0x1234, 2, [0x12, 0x34]), + (0, 0x12345678, 4, [0x78, 0x56, 0x34, 0x12]), + (1, 0x12345678, 4, [0x56, 0x78, 0x12, 0x34]), + ], + ids=[ + "P0: 1 byte", + "P1: 1 byte", + "P0: 2 bytes", + "P1: 2 bytes", + "P0: 4 bytes", + "P1: 4 bytes", + ], +) # fmt: skip +def test__split_into_byte_chunks(protocol, value, length, expected): + bus = FeetechMotorsBus("", {}, protocol_version=protocol) + assert bus._split_into_byte_chunks(value, length) == expected + + +def test_abc_implementation(dummy_motors): + """Instantiation should raise an error if the class doesn't implement abstract methods/properties.""" + FeetechMotorsBus(port="/dev/dummy-port", motors=dummy_motors) + + +@pytest.mark.parametrize("id_", [1, 2, 3]) +def test_ping(id_, mock_motors, dummy_motors): + expected_model_nb = MODEL_NUMBER_TABLE[dummy_motors[f"dummy_{id_}"].model] + addr, length = MODEL_NUMBER + ping_stub = mock_motors.build_ping_stub(id_) + mobel_nb_stub = mock_motors.build_read_stub(addr, length, id_, expected_model_nb) + bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + bus.connect(handshake=False) + + ping_model_nb = bus.ping(id_) + + assert ping_model_nb == expected_model_nb + assert mock_motors.stubs[ping_stub].called + assert mock_motors.stubs[mobel_nb_stub].called + + +def test_broadcast_ping(mock_motors, dummy_motors): + models = {m.id: m.model for m in dummy_motors.values()} + addr, length = MODEL_NUMBER + ping_stub = mock_motors.build_broadcast_ping_stub(list(models)) + mobel_nb_stubs = [] + expected_model_nbs = {} + for id_, model in models.items(): + model_nb = MODEL_NUMBER_TABLE[model] + stub = mock_motors.build_read_stub(addr, length, id_, model_nb) + expected_model_nbs[id_] = model_nb + mobel_nb_stubs.append(stub) + bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + bus.connect(handshake=False) + + ping_model_nbs = bus.broadcast_ping() + + assert ping_model_nbs == expected_model_nbs + assert mock_motors.stubs[ping_stub].called + assert all(mock_motors.stubs[stub].called for stub in mobel_nb_stubs) + + +@pytest.mark.parametrize( + "addr, length, id_, value", + [ + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), + ], +) +def test__read(addr, length, id_, value, mock_motors, dummy_motors): + stub = mock_motors.build_read_stub(addr, length, id_, value) + bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + bus.connect(handshake=False) + + read_value, _, _ = bus._read(addr, length, id_) + + assert mock_motors.stubs[stub].called + assert read_value == value + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) + stub = mock_motors.build_read_stub(addr, length, id_, value, error=error) + bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): + bus._read(addr, length, id_, raise_on_error=raise_on_error) + else: + _, _, read_error = bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_error == error + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub = mock_motors.build_read_stub(addr, length, id_, value, reply=False) + bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + bus._read(addr, length, id_, raise_on_error=raise_on_error) + else: + _, read_comm, _ = bus._read(addr, length, id_, raise_on_error=raise_on_error) + assert read_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, id_, value", + [ + (0, 1, 1, 2), + (10, 2, 2, 999), + (42, 4, 3, 1337), + ], +) +def test__write(addr, length, id_, value, mock_motors, dummy_motors): + stub = mock_motors.build_write_stub(addr, length, id_, value) + bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + ) + bus.connect(handshake=False) + + comm, error = bus._write(addr, length, id_, value) + + assert mock_motors.stubs[stub].called + assert comm == scs.COMM_SUCCESS + assert error == 0 + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_error(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value, error = (10, 4, 1, 1337, scs.ERRBIT_VOLTAGE) + stub = mock_motors.build_write_stub(addr, length, id_, value, error=error) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(RuntimeError, match=re.escape("[RxPacketError] Input voltage error!")): + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + _, write_error = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_error == error + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__write_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, id_, value = (10, 4, 1, 1337) + stub = mock_motors.build_write_stub(addr, length, id_, value, reply=False) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + else: + write_comm, _ = bus._write(addr, length, id_, value, raise_on_error=raise_on_error) + assert write_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_read(addr, length, ids_values, mock_motors, dummy_motors): + stub = mock_motors.build_sync_read_stub(addr, length, ids_values) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + read_values, _ = bus._sync_read(addr, length, list(ids_values)) + + assert mock_motors.stubs[stub].called + assert read_values == ids_values + + +@pytest.mark.parametrize("raise_on_error", (True, False)) +def test__sync_read_comm(raise_on_error, mock_motors, dummy_motors): + addr, length, ids_values = (10, 4, {1: 1337}) + stub = mock_motors.build_sync_read_stub(addr, length, ids_values, reply=False) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + if raise_on_error: + with pytest.raises(ConnectionError, match=re.escape("[TxRxResult] There is no status packet!")): + bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + else: + _, read_comm = bus._sync_read(addr, length, list(ids_values), raise_on_error=raise_on_error) + assert read_comm == scs.COMM_RX_TIMEOUT + + assert mock_motors.stubs[stub].called + + +@pytest.mark.parametrize( + "addr, length, ids_values", + [ + (0, 1, {1: 4}), + (10, 2, {1: 1337, 2: 42}), + (42, 4, {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test__sync_write(addr, length, ids_values, mock_motors, dummy_motors): + stub = mock_motors.build_sync_write_stub(addr, length, ids_values) + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + comm = bus._sync_write(addr, length, ids_values) + + assert mock_motors.stubs[stub].wait_called() + assert comm == scs.COMM_SUCCESS + + +def test_is_calibrated(mock_motors, dummy_motors, dummy_calibration): + mins_stubs, maxes_stubs, homings_stubs = [], [], [] + for cal in dummy_calibration.values(): + mins_stubs.append( + mock_motors.build_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], cal.id, cal.range_min + ) + ) + maxes_stubs.append( + mock_motors.build_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], cal.id, cal.range_max + ) + ) + homings_stubs.append( + mock_motors.build_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], + cal.id, + encode_sign_magnitude(cal.homing_offset, 11), + ) + ) + + bus = FeetechMotorsBus( + port=mock_motors.port, + motors=dummy_motors, + calibration=dummy_calibration, + ) + bus.connect(handshake=False) + + is_calibrated = bus.is_calibrated + + assert is_calibrated + assert all(mock_motors.stubs[stub].called for stub in mins_stubs) + assert all(mock_motors.stubs[stub].called for stub in maxes_stubs) + assert all(mock_motors.stubs[stub].called for stub in homings_stubs) + + +def test_reset_calibration(mock_motors, dummy_motors): + write_homing_stubs = [] + write_mins_stubs = [] + write_maxes_stubs = [] + for motor in dummy_motors.values(): + write_homing_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], motor.id, 0) + ) + write_mins_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Min_Position_Limit"], motor.id, 0) + ) + write_maxes_stubs.append( + mock_motors.build_write_stub(*STS_SMS_SERIES_CONTROL_TABLE["Max_Position_Limit"], motor.id, 4095) + ) + + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + bus.reset_calibration() + + assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) + assert all(mock_motors.stubs[stub].called for stub in write_mins_stubs) + assert all(mock_motors.stubs[stub].called for stub in write_maxes_stubs) + + +def test_set_half_turn_homings(mock_motors, dummy_motors): + """ + For this test, we assume that the homing offsets are already 0 such that + Present_Position == Actual_Position + """ + current_positions = { + 1: 1337, + 2: 42, + 3: 3672, + } + expected_homings = { + 1: -710, # 1337 - 2047 + 2: -2005, # 42 - 2047 + 3: 1625, # 3672 - 2047 + } + read_pos_stub = mock_motors.build_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], current_positions + ) + write_homing_stubs = [] + for id_, homing in expected_homings.items(): + encoded_homing = encode_sign_magnitude(homing, 11) + stub = mock_motors.build_write_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Homing_Offset"], id_, encoded_homing + ) + write_homing_stubs.append(stub) + + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + bus.reset_calibration = MagicMock() + + bus.set_half_turn_homings() + + bus.reset_calibration.assert_called_once() + assert mock_motors.stubs[read_pos_stub].called + assert all(mock_motors.stubs[stub].called for stub in write_homing_stubs) + + +def test_record_ranges_of_motion(mock_motors, dummy_motors): + positions = { + 1: [351, 42, 1337], + 2: [28, 3600, 2444], + 3: [4002, 2999, 146], + } + expected_mins = { + "dummy_1": 42, + "dummy_2": 28, + "dummy_3": 146, + } + expected_maxes = { + "dummy_1": 1337, + "dummy_2": 3600, + "dummy_3": 4002, + } + stub = mock_motors.build_sequential_sync_read_stub( + *STS_SMS_SERIES_CONTROL_TABLE["Present_Position"], positions + ) + with patch("lerobot.common.motors.motors_bus.enter_pressed", side_effect=[False, True]): + bus = FeetechMotorsBus(port=mock_motors.port, motors=dummy_motors) + bus.connect(handshake=False) + + mins, maxes = bus.record_ranges_of_motion(display_values=False) + + assert mock_motors.stubs[stub].calls == 3 + assert mins == expected_mins + assert maxes == expected_maxes diff --git a/tests/motors/test_motors.py b/tests/motors/test_motors.py deleted file mode 100644 index da7a5c543..000000000 --- a/tests/motors/test_motors.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for physical motors and their mocked versions. -If the physical motors are not connected to the computer, or not working, -the test will be skipped. - -Example of running a specific test: -```bash -pytest -sx tests/test_motors.py::test_find_port -pytest -sx tests/test_motors.py::test_motors_bus -``` - -Example of running test on real dynamixel motors connected to the computer: -```bash -pytest -sx 'tests/test_motors.py::test_motors_bus[dynamixel-False]' -``` - -Example of running test on a mocked version of dynamixel motors: -```bash -pytest -sx 'tests/test_motors.py::test_motors_bus[dynamixel-True]' -``` -""" - -# TODO(rcadene): measure fps in nightly? -# TODO(rcadene): test logs -# TODO(rcadene): test calibration -# TODO(rcadene): add compatibility with other motors bus - -import time - -import numpy as np -import pytest - -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError -from lerobot.scripts.find_motors_bus_port import find_port -from tests.utils import TEST_MOTOR_TYPES, make_motors_bus, require_motor - - -@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES) -@require_motor -def test_find_port(request, motor_type, mock): - if mock: - request.getfixturevalue("patch_builtins_input") - with pytest.raises(OSError): - find_port() - else: - find_port() - - -@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES) -@require_motor -def test_configure_motors_all_ids_1(request, motor_type, mock): - if mock: - request.getfixturevalue("patch_builtins_input") - - if motor_type == "dynamixel": - # see X_SERIES_BAUDRATE_TABLE - smaller_baudrate = 9_600 - smaller_baudrate_value = 0 - elif motor_type == "feetech": - # see SCS_SERIES_BAUDRATE_TABLE - smaller_baudrate = 19_200 - smaller_baudrate_value = 7 - else: - raise ValueError(motor_type) - - input("Are you sure you want to re-configure the motors? Press enter to continue...") - # This test expect the configuration was already correct. - motors_bus = make_motors_bus(motor_type, mock=mock) - motors_bus.connect() - motors_bus.write("Baud_Rate", [smaller_baudrate_value] * len(motors_bus.motors)) - - motors_bus.set_bus_baudrate(smaller_baudrate) - motors_bus.write("ID", [1] * len(motors_bus.motors)) - del motors_bus - - # Test configure - motors_bus = make_motors_bus(motor_type, mock=mock) - motors_bus.connect() - assert motors_bus.are_motors_configured() - del motors_bus - - -@pytest.mark.parametrize("motor_type, mock", TEST_MOTOR_TYPES) -@require_motor -def test_motors_bus(request, motor_type, mock): - if mock: - request.getfixturevalue("patch_builtins_input") - - motors_bus = make_motors_bus(motor_type, mock=mock) - - # Test reading and writing before connecting raises an error - with pytest.raises(RobotDeviceNotConnectedError): - motors_bus.read("Torque_Enable") - with pytest.raises(RobotDeviceNotConnectedError): - motors_bus.write("Torque_Enable", 1) - with pytest.raises(RobotDeviceNotConnectedError): - motors_bus.disconnect() - - # Test deleting the object without connecting first - del motors_bus - - # Test connecting - motors_bus = make_motors_bus(motor_type, mock=mock) - motors_bus.connect() - - # Test connecting twice raises an error - with pytest.raises(RobotDeviceAlreadyConnectedError): - motors_bus.connect() - - # Test disabling torque and reading torque on all motors - motors_bus.write("Torque_Enable", 0) - values = motors_bus.read("Torque_Enable") - assert isinstance(values, np.ndarray) - assert len(values) == len(motors_bus.motors) - assert (values == 0).all() - - # Test writing torque on a specific motor - motors_bus.write("Torque_Enable", 1, "gripper") - - # Test reading torque from this specific motor. It is now 1 - values = motors_bus.read("Torque_Enable", "gripper") - assert len(values) == 1 - assert values[0] == 1 - - # Test reading torque from all motors. It is 1 for the specific motor, - # and 0 on the others. - values = motors_bus.read("Torque_Enable") - gripper_index = motors_bus.motor_names.index("gripper") - assert values[gripper_index] == 1 - assert values.sum() == 1 # gripper is the only motor to have torque 1 - - # Test writing torque on all motors and it is 1 for all. - motors_bus.write("Torque_Enable", 1) - values = motors_bus.read("Torque_Enable") - assert (values == 1).all() - - # Test ordering the motors to move slightly (+1 value among 4096) and this move - # can be executed and seen by the motor position sensor - values = motors_bus.read("Present_Position") - motors_bus.write("Goal_Position", values + 1) - # Give time for the motors to move to the goal position - time.sleep(1) - new_values = motors_bus.read("Present_Position") - assert (new_values == values).all() diff --git a/tests/motors/test_motors_bus.py b/tests/motors/test_motors_bus.py new file mode 100644 index 000000000..78b7a47da --- /dev/null +++ b/tests/motors/test_motors_bus.py @@ -0,0 +1,342 @@ +import re +from unittest.mock import patch + +import pytest + +from lerobot.common.motors.motors_bus import ( + Motor, + MotorNormMode, + assert_same_address, + get_address, + get_ctrl_table, +) +from tests.mocks.mock_motors_bus import ( + DUMMY_CTRL_TABLE_1, + DUMMY_CTRL_TABLE_2, + DUMMY_MODEL_CTRL_TABLE, + MockMotorsBus, +) + + +@pytest.fixture +def dummy_motors() -> dict[str, Motor]: + return { + "dummy_1": Motor(1, "model_2", MotorNormMode.RANGE_M100_100), + "dummy_2": Motor(2, "model_3", MotorNormMode.RANGE_M100_100), + "dummy_3": Motor(3, "model_2", MotorNormMode.RANGE_0_100), + } + + +def test_get_ctrl_table(): + model = "model_1" + ctrl_table = get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) + assert ctrl_table == DUMMY_CTRL_TABLE_1 + + +def test_get_ctrl_table_error(): + model = "model_99" + with pytest.raises(KeyError, match=f"Control table for {model=} not found."): + get_ctrl_table(DUMMY_MODEL_CTRL_TABLE, model) + + +def test_get_address(): + addr, n_bytes = get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", "Firmware_Version") + assert addr == 0 + assert n_bytes == 1 + + +def test_get_address_error(): + model = "model_1" + data_name = "Lock" + with pytest.raises(KeyError, match=f"Address for '{data_name}' not found in {model} control table."): + get_address(DUMMY_MODEL_CTRL_TABLE, "model_1", data_name) + + +def test_assert_same_address(): + models = ["model_1", "model_2"] + assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Present_Position") + + +def test_assert_same_length_different_addresses(): + models = ["model_1", "model_2"] + with pytest.raises( + NotImplementedError, + match=re.escape("At least two motor models use a different address"), + ): + assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Model_Number") + + +def test_assert_same_address_different_length(): + models = ["model_1", "model_2"] + with pytest.raises( + NotImplementedError, + match=re.escape("At least two motor models use a different bytes representation"), + ): + assert_same_address(DUMMY_MODEL_CTRL_TABLE, models, "Goal_Position") + + +def test__serialize_data_invalid_length(): + bus = MockMotorsBus("", {}) + with pytest.raises(NotImplementedError): + bus._serialize_data(100, 3) + + +def test__serialize_data_negative_numbers(): + bus = MockMotorsBus("", {}) + with pytest.raises(ValueError): + bus._serialize_data(-1, 1) + + +def test__serialize_data_large_number(): + bus = MockMotorsBus("", {}) + with pytest.raises(ValueError): + bus._serialize_data(2**32, 4) # 4-byte max is 0xFFFFFFFF + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Firmware_Version", 1, 14), + ("Model_Number", 1, 5678), + ("Present_Position", 2, 1337), + ("Present_Velocity", 3, 42), + ], +) +def test_read(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(handshake=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + + with ( + patch.object(MockMotorsBus, "_read", return_value=(value, 0, 0)) as mock__read, + patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, + ): + returned_value = bus.read(data_name, f"dummy_{id_}") + + assert returned_value == value + mock__read.assert_called_once_with( + addr, + length, + id_, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to read '{data_name}' on {id_=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with({id_: value}) + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Goal_Position", 1, 1337), + ("Goal_Velocity", 2, 3682), + ("Lock", 3, 1), + ], +) +def test_write(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(handshake=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + + with ( + patch.object(MockMotorsBus, "_write", return_value=(0, 0)) as mock__write, + patch.object(MockMotorsBus, "_encode_sign", return_value={id_: value}) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value={id_: value}) as mock__unnormalize, + ): + bus.write(data_name, f"dummy_{id_}", value) + + mock__write.assert_called_once_with( + addr, + length, + id_, + value, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to write '{data_name}' on {id_=} with '{value}' after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with({id_: value}) + + +@pytest.mark.parametrize( + "data_name, id_, value", + [ + ("Firmware_Version", 1, 14), + ("Model_Number", 1, 5678), + ("Present_Position", 2, 1337), + ("Present_Velocity", 3, 42), + ], +) +def test_sync_read_by_str(data_name, id_, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(handshake=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = [id_] + expected_value = {f"dummy_{id_}": value} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=({id_: value}, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value={id_: value}) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value={id_: value}) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name, f"dummy_{id_}") + + assert returned_dict == expected_value + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, {id_: value}) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with({id_: value}) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Model_Number", {1: 5678}), + ("Present_Position", {1: 1337, 2: 42}), + ("Present_Velocity", {1: 1337, 2: 42, 3: 4016}), + ], + ids=["1 motor", "2 motors", "3 motors"], +) +def test_sync_read_by_list(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(handshake=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = list(ids_values) + expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name, [f"dummy_{id_}" for id_ in ids]) + + assert returned_dict == expected_values + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(ids_values) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Model_Number", {1: 5678, 2: 5799, 3: 5678}), + ("Present_Position", {1: 1337, 2: 42, 3: 4016}), + ("Goal_Position", {1: 4008, 2: 199, 3: 3446}), + ], + ids=["Model_Number", "Present_Position", "Goal_Position"], +) +def test_sync_read_by_none(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(handshake=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids = list(ids_values) + expected_values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_read", return_value=(ids_values, 0)) as mock__sync_read, + patch.object(MockMotorsBus, "_decode_sign", return_value=ids_values) as mock__decode_sign, + patch.object(MockMotorsBus, "_normalize", return_value=ids_values) as mock__normalize, + ): + returned_dict = bus.sync_read(data_name) + + assert returned_dict == expected_values + mock__sync_read.assert_called_once_with( + addr, + length, + ids, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync read '{data_name}' on {ids=} after 1 tries.", + ) + mock__decode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__normalize.assert_called_once_with(ids_values) + + +@pytest.mark.parametrize( + "data_name, value", + [ + ("Goal_Position", 500), + ("Goal_Velocity", 4010), + ("Lock", 0), + ], +) +def test_sync_write_by_single_value(data_name, value, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(handshake=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + ids_values = {m.id: value for m in dummy_motors.values()} + + with ( + patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, + patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, + ): + bus.sync_write(data_name, value) + + mock__sync_write.assert_called_once_with( + addr, + length, + ids_values, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(ids_values) + + +@pytest.mark.parametrize( + "data_name, ids_values", + [ + ("Goal_Position", {1: 1337, 2: 42, 3: 4016}), + ("Goal_Velocity", {1: 50, 2: 83, 3: 2777}), + ("Lock", {1: 0, 2: 0, 3: 1}), + ], + ids=["Goal_Position", "Goal_Velocity", "Lock"], +) +def test_sync_write_by_value_dict(data_name, ids_values, dummy_motors): + bus = MockMotorsBus("/dev/dummy-port", dummy_motors) + bus.connect(handshake=False) + addr, length = DUMMY_CTRL_TABLE_2[data_name] + values = {f"dummy_{id_}": val for id_, val in ids_values.items()} + + with ( + patch.object(MockMotorsBus, "_sync_write", return_value=(ids_values, 0)) as mock__sync_write, + patch.object(MockMotorsBus, "_encode_sign", return_value=ids_values) as mock__encode_sign, + patch.object(MockMotorsBus, "_unnormalize", return_value=ids_values) as mock__unnormalize, + ): + bus.sync_write(data_name, values) + + mock__sync_write.assert_called_once_with( + addr, + length, + ids_values, + num_retry=0, + raise_on_error=True, + err_msg=f"Failed to sync write '{data_name}' with {ids_values=} after 1 tries.", + ) + mock__encode_sign.assert_called_once_with(data_name, ids_values) + if data_name in bus.normalized_data: + mock__unnormalize.assert_called_once_with(ids_values) diff --git a/tests/optim/test_schedulers.py b/tests/optim/test_schedulers.py index 176376633..d6191dcea 100644 --- a/tests/optim/test_schedulers.py +++ b/tests/optim/test_schedulers.py @@ -37,7 +37,6 @@ def test_diffuser_scheduler(optimizer): "base_lrs": [0.001], "last_epoch": 1, "lr_lambdas": [None], - "verbose": False, } assert scheduler.state_dict() == expected_state_dict @@ -56,7 +55,6 @@ def test_vqbet_scheduler(optimizer): "base_lrs": [0.001], "last_epoch": 1, "lr_lambdas": [None], - "verbose": False, } assert scheduler.state_dict() == expected_state_dict @@ -77,7 +75,6 @@ def test_cosine_decay_with_warmup_scheduler(optimizer): "base_lrs": [0.001], "last_epoch": 1, "lr_lambdas": [None], - "verbose": False, } assert scheduler.state_dict() == expected_state_dict diff --git a/tests/policies/hilserl/test_modeling_classifier.py b/tests/policies/hilserl/test_modeling_classifier.py new file mode 100644 index 000000000..526e1f17d --- /dev/null +++ b/tests/policies/hilserl/test_modeling_classifier.py @@ -0,0 +1,139 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.common.policies.sac.reward_model.modeling_classifier import ClassifierOutput +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from tests.utils import require_package + + +def test_classifier_output(): + output = ClassifierOutput( + logits=torch.tensor([1, 2, 3]), + probabilities=torch.tensor([0.1, 0.2, 0.3]), + hidden_states=None, + ) + + assert ( + f"{output}" + == "ClassifierOutput(logits=tensor([1, 2, 3]), probabilities=tensor([0.1000, 0.2000, 0.3000]), hidden_states=None)" + ) + + +@require_package("transformers") +def test_binary_classifier_with_default_params(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig() + config.input_features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(1,)), + } + config.normalization_mapping = { + "VISUAL": NormalizationMode.IDENTITY, + "REWARD": NormalizationMode.IDENTITY, + } + config.num_cameras = 1 + classifier = Classifier(config) + + batch_size = 10 + + input = { + "observation.image": torch.rand((batch_size, 3, 128, 128)), + "next.reward": torch.randint(low=0, high=2, size=(batch_size,)).float(), + } + + images, labels = classifier.extract_images_and_labels(input) + assert len(images) == 1 + assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) + assert labels.shape == torch.Size([batch_size]) + + output = classifier.predict(images) + + assert output is not None + assert output.logits.size() == torch.Size([batch_size]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 256]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_multiclass_classifier(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + num_classes = 5 + config = RewardClassifierConfig() + config.input_features = { + "observation.image": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), + } + config.output_features = { + "next.reward": PolicyFeature(type=FeatureType.REWARD, shape=(num_classes,)), + } + config.num_cameras = 1 + config.num_classes = num_classes + classifier = Classifier(config) + + batch_size = 10 + + input = { + "observation.image": torch.rand((batch_size, 3, 128, 128)), + "next.reward": torch.rand((batch_size, num_classes)), + } + + images, labels = classifier.extract_images_and_labels(input) + assert len(images) == 1 + assert images[0].shape == torch.Size([batch_size, 3, 128, 128]) + assert labels.shape == torch.Size([batch_size, num_classes]) + + output = classifier.predict(images) + + assert output is not None + assert output.logits.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.logits).any(), "Tensor contains NaN values" + assert output.probabilities.shape == torch.Size([batch_size, num_classes]) + assert not torch.isnan(output.probabilities).any(), "Tensor contains NaN values" + assert output.hidden_states.shape == torch.Size([batch_size, 256]) + assert not torch.isnan(output.hidden_states).any(), "Tensor contains NaN values" + + +@require_package("transformers") +def test_default_device(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig() + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") + + +@require_package("transformers") +def test_explicit_device_setup(): + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + config = RewardClassifierConfig(device="cpu") + assert config.device == "cpu" + + classifier = Classifier(config) + for p in classifier.parameters(): + assert p.device == torch.device("cpu") diff --git a/tests/policies/test_sac_config.py b/tests/policies/test_sac_config.py new file mode 100644 index 000000000..d94ee41e0 --- /dev/null +++ b/tests/policies/test_sac_config.py @@ -0,0 +1,217 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from lerobot.common.policies.sac.configuration_sac import ( + ActorLearnerConfig, + ActorNetworkConfig, + ConcurrencyConfig, + CriticNetworkConfig, + PolicyConfig, + SACConfig, +) +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature + + +def test_sac_config_default_initialization(): + config = SACConfig() + + assert config.normalization_mapping == { + "VISUAL": NormalizationMode.MEAN_STD, + "STATE": NormalizationMode.MIN_MAX, + "ENV": NormalizationMode.MIN_MAX, + "ACTION": NormalizationMode.MIN_MAX, + } + assert config.dataset_stats == { + "observation.image": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + "observation.state": { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + "action": { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + + # Basic parameters + assert config.device == "cpu" + assert config.storage_device == "cpu" + assert config.discount == 0.99 + assert config.temperature_init == 1.0 + assert config.num_critics == 2 + + # Architecture specifics + assert config.vision_encoder_name is None + assert config.freeze_vision_encoder is True + assert config.image_encoder_hidden_dim == 32 + assert config.shared_encoder is True + assert config.num_discrete_actions is None + assert config.image_embedding_pooling_dim == 8 + + # Training parameters + assert config.online_steps == 1000000 + assert config.online_env_seed == 10000 + assert config.online_buffer_capacity == 100000 + assert config.offline_buffer_capacity == 100000 + assert config.async_prefetch is False + assert config.online_step_before_learning == 100 + assert config.policy_update_freq == 1 + + # SAC algorithm parameters + assert config.num_subsample_critics is None + assert config.critic_lr == 3e-4 + assert config.actor_lr == 3e-4 + assert config.temperature_lr == 3e-4 + assert config.critic_target_update_weight == 0.005 + assert config.utd_ratio == 1 + assert config.state_encoder_hidden_dim == 256 + assert config.latent_dim == 256 + assert config.target_entropy is None + assert config.use_backup_entropy is True + assert config.grad_clip_norm == 40.0 + + # Dataset stats defaults + expected_dataset_stats = { + "observation.image": { + "mean": [0.485, 0.456, 0.406], + "std": [0.229, 0.224, 0.225], + }, + "observation.state": { + "min": [0.0, 0.0], + "max": [1.0, 1.0], + }, + "action": { + "min": [0.0, 0.0, 0.0], + "max": [1.0, 1.0, 1.0], + }, + } + assert config.dataset_stats == expected_dataset_stats + + # Critic network configuration + assert config.critic_network_kwargs.hidden_dims == [256, 256] + assert config.critic_network_kwargs.activate_final is True + assert config.critic_network_kwargs.final_activation is None + + # Actor network configuration + assert config.actor_network_kwargs.hidden_dims == [256, 256] + assert config.actor_network_kwargs.activate_final is True + + # Policy configuration + assert config.policy_kwargs.use_tanh_squash is True + assert config.policy_kwargs.std_min == 1e-5 + assert config.policy_kwargs.std_max == 10.0 + assert config.policy_kwargs.init_final == 0.05 + + # Discrete critic network configuration + assert config.discrete_critic_network_kwargs.hidden_dims == [256, 256] + assert config.discrete_critic_network_kwargs.activate_final is True + assert config.discrete_critic_network_kwargs.final_activation is None + + # Actor learner configuration + assert config.actor_learner_config.learner_host == "127.0.0.1" + assert config.actor_learner_config.learner_port == 50051 + assert config.actor_learner_config.policy_parameters_push_frequency == 4 + + # Concurrency configuration + assert config.concurrency.actor == "threads" + assert config.concurrency.learner == "threads" + + assert isinstance(config.actor_network_kwargs, ActorNetworkConfig) + assert isinstance(config.critic_network_kwargs, CriticNetworkConfig) + assert isinstance(config.policy_kwargs, PolicyConfig) + assert isinstance(config.actor_learner_config, ActorLearnerConfig) + assert isinstance(config.concurrency, ConcurrencyConfig) + + +def test_critic_network_kwargs(): + config = CriticNetworkConfig() + assert config.hidden_dims == [256, 256] + assert config.activate_final is True + assert config.final_activation is None + + +def test_actor_network_kwargs(): + config = ActorNetworkConfig() + assert config.hidden_dims == [256, 256] + assert config.activate_final is True + + +def test_policy_kwargs(): + config = PolicyConfig() + assert config.use_tanh_squash is True + assert config.std_min == 1e-5 + assert config.std_max == 10.0 + assert config.init_final == 0.05 + + +def test_actor_learner_config(): + config = ActorLearnerConfig() + assert config.learner_host == "127.0.0.1" + assert config.learner_port == 50051 + assert config.policy_parameters_push_frequency == 4 + + +def test_concurrency_config(): + config = ConcurrencyConfig() + assert config.actor == "threads" + assert config.learner == "threads" + + +def test_sac_config_custom_initialization(): + config = SACConfig( + device="cpu", + discount=0.95, + temperature_init=0.5, + num_critics=3, + ) + + assert config.device == "cpu" + assert config.discount == 0.95 + assert config.temperature_init == 0.5 + assert config.num_critics == 3 + + +def test_validate_features(): + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + config.validate_features() + + +def test_validate_features_missing_observation(): + config = SACConfig( + input_features={"wrong_key": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + with pytest.raises( + ValueError, match="You must provide either 'observation.state' or an image observation" + ): + config.validate_features() + + +def test_validate_features_missing_action(): + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,))}, + output_features={"wrong_key": PolicyFeature(type=FeatureType.ACTION, shape=(3,))}, + ) + with pytest.raises(ValueError, match="You must provide 'action' in the output features"): + config.validate_features() diff --git a/tests/policies/test_sac_policy.py b/tests/policies/test_sac_policy.py new file mode 100644 index 000000000..e4e2dd8a9 --- /dev/null +++ b/tests/policies/test_sac_policy.py @@ -0,0 +1,541 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import pytest +import torch +from torch import Tensor, nn + +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.modeling_sac import MLP, SACPolicy +from lerobot.common.utils.random_utils import seeded_context, set_seed +from lerobot.configs.types import FeatureType, PolicyFeature + +try: + import transformers # noqa: F401 + + TRANSFORMERS_AVAILABLE = True +except ImportError: + TRANSFORMERS_AVAILABLE = False + + +@pytest.fixture(autouse=True) +def set_random_seed(): + seed = 42 + set_seed(seed) + + +def test_mlp_with_default_args(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + + x = torch.randn(10) + y = mlp(x) + assert y.shape == (256,) + + +def test_mlp_with_batch_dim(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256]) + x = torch.randn(2, 10) + y = mlp(x) + assert y.shape == (2, 256) + + +def test_forward_with_empty_hidden_dims(): + mlp = MLP(input_dim=10, hidden_dims=[]) + x = torch.randn(1, 10) + assert mlp(x).shape == (1, 10) + + +def test_mlp_with_dropout(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256, 11], dropout_rate=0.1) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 11) + + drop_out_layers_count = sum(isinstance(layer, nn.Dropout) for layer in mlp.net) + assert drop_out_layers_count == 2 + + +def test_mlp_with_custom_final_activation(): + mlp = MLP(input_dim=10, hidden_dims=[256, 256], final_activation=torch.nn.Tanh()) + x = torch.randn(1, 10) + y = mlp(x) + assert y.shape == (1, 256) + assert (y >= -1).all() and (y <= 1).all() + + +def test_sac_policy_with_default_args(): + with pytest.raises(ValueError, match="should be an instance of class `PreTrainedConfig`"): + SACPolicy() + + +def create_dummy_state(batch_size: int, state_dim: int = 10) -> Tensor: + return { + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_dummy_with_visual_input(batch_size: int, state_dim: int = 10) -> Tensor: + return { + "observation.image": torch.randn(batch_size, 3, 84, 84), + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_dummy_action(batch_size: int, action_dim: int = 10) -> Tensor: + return torch.randn(batch_size, action_dim) + + +def create_default_train_batch( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + "action": create_dummy_action(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_state(batch_size, state_dim), + "next_state": create_dummy_state(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_train_batch_with_visual_input( + batch_size: int = 8, state_dim: int = 10, action_dim: int = 10 +) -> dict[str, Tensor]: + return { + "action": create_dummy_action(batch_size, action_dim), + "reward": torch.randn(batch_size), + "state": create_dummy_with_visual_input(batch_size, state_dim), + "next_state": create_dummy_with_visual_input(batch_size, state_dim), + "done": torch.randn(batch_size), + } + + +def create_observation_batch(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + "observation.state": torch.randn(batch_size, state_dim), + } + + +def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: int = 10) -> dict[str, Tensor]: + return { + "observation.state": torch.randn(batch_size, state_dim), + "observation.image": torch.randn(batch_size, 3, 84, 84), + } + + +def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]: + """Create optimizers for the SAC policy.""" + optimizer_actor = torch.optim.Adam( + # Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient + params=[ + p + for n, p in policy.actor.named_parameters() + if not policy.config.shared_encoder or not n.startswith("encoder") + ], + lr=policy.config.actor_lr, + ) + optimizer_critic = torch.optim.Adam( + params=policy.critic_ensemble.parameters(), + lr=policy.config.critic_lr, + ) + optimizer_temperature = torch.optim.Adam( + params=[policy.log_alpha], + lr=policy.config.critic_lr, + ) + + optimizers = { + "actor": optimizer_actor, + "critic": optimizer_critic, + "temperature": optimizer_temperature, + } + + if has_discrete_action: + optimizers["discrete_critic"] = torch.optim.Adam( + params=policy.discrete_critic.parameters(), + lr=policy.config.critic_lr, + ) + + return optimizers + + +def create_default_config( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> SACConfig: + action_dim = continuous_action_dim + if has_discrete_action: + action_dim += 1 + + config = SACConfig( + input_features={"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))}, + output_features={"action": PolicyFeature(type=FeatureType.ACTION, shape=(continuous_action_dim,))}, + dataset_stats={ + "observation.state": { + "min": [0.0] * state_dim, + "max": [1.0] * state_dim, + }, + "action": { + "min": [0.0] * continuous_action_dim, + "max": [1.0] * continuous_action_dim, + }, + }, + ) + config.validate_features() + return config + + +def create_config_with_visual_input( + state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False +) -> SACConfig: + config = create_default_config( + state_dim=state_dim, + continuous_action_dim=continuous_action_dim, + has_discrete_action=has_discrete_action, + ) + config.input_features["observation.image"] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84)) + config.dataset_stats["observation.image"] = { + "mean": torch.randn(3, 1, 1), + "std": torch.randn(3, 1, 1), + } + + # Let make tests a little bit faster + config.state_encoder_hidden_dim = 32 + config.latent_dim = 32 + + config.validate_features() + return config + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int): + batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim) + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + + policy = SACPolicy(config=config) + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + assert temperature_loss.item() is not None + assert temperature_loss.shape == () + + temperature_loss.backward() + optimizers["temperature"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)]) +def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + assert temperature_loss.item() is not None + assert temperature_loss.shape == () + + temperature_loss.backward() + optimizers["temperature"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, action_dim) + + +# Let's check best candidates for pretrained encoders +@pytest.mark.parametrize( + "batch_size,state_dim,action_dim,vision_encoder_name", + [(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")], +) +@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed") +def test_sac_policy_with_pretrained_encoder( + batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str +): + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.vision_encoder_name = vision_encoder_name + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + +def test_sac_policy_with_shared_encoder(): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.shared_encoder = True + + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + +def test_sac_policy_with_discrete_critic(): + batch_size = 2 + continuous_action_dim = 9 + full_action_dim = continuous_action_dim + 1 # the last action is discrete + state_dim = 10 + config = create_config_with_visual_input( + state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True + ) + + num_discrete_actions = 5 + config.num_discrete_actions = num_discrete_actions + + policy = SACPolicy(config=config) + policy.train() + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy, has_discrete_action=True) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"] + assert discrete_critic_loss.item() is not None + assert discrete_critic_loss.shape == () + discrete_critic_loss.backward() + optimizers["discrete_critic"].step() + + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + assert actor_loss.item() is not None + assert actor_loss.shape == () + + actor_loss.backward() + optimizers["actor"].step() + + policy.eval() + with torch.no_grad(): + observation_batch = create_observation_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim + ) + selected_action = policy.select_action(observation_batch) + assert selected_action.shape == (batch_size, full_action_dim) + + discrete_actions = selected_action[:, -1].long() + discrete_action_values = set(discrete_actions.tolist()) + + assert all(action in range(num_discrete_actions) for action in discrete_action_values), ( + f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})" + ) + + +def test_sac_policy_with_default_entropy(): + config = create_default_config(continuous_action_dim=10, state_dim=10) + policy = SACPolicy(config=config) + assert policy.target_entropy == -5.0 + + +def test_sac_policy_default_target_entropy_with_discrete_action(): + config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True) + policy = SACPolicy(config=config) + assert policy.target_entropy == -3.0 + + +def test_sac_policy_with_predefined_entropy(): + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.target_entropy = -3.5 + + policy = SACPolicy(config=config) + assert policy.target_entropy == pytest.approx(-3.5) + + +def test_sac_policy_update_temperature(): + config = create_default_config(continuous_action_dim=10, state_dim=10) + policy = SACPolicy(config=config) + + assert policy.temperature == pytest.approx(1.0) + policy.log_alpha.data = torch.tensor([math.log(0.1)]) + policy.update_temperature() + assert policy.temperature == pytest.approx(0.1) + + +def test_sac_policy_update_target_network(): + config = create_default_config(state_dim=10, continuous_action_dim=6) + config.critic_target_update_weight = 1.0 + + policy = SACPolicy(config=config) + policy.train() + + for p in policy.critic_ensemble.parameters(): + p.data = torch.ones_like(p.data) + + policy.update_target_networks() + for p in policy.critic_target.parameters(): + assert torch.allclose(p.data, torch.ones_like(p.data)), ( + f"Target network {p.data} is not equal to {torch.ones_like(p.data)}" + ) + + +@pytest.mark.parametrize("num_critics", [1, 3]) +def test_sac_policy_with_critics_number_of_heads(num_critics: int): + batch_size = 2 + action_dim = 10 + state_dim = 10 + config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim) + config.num_critics = num_critics + + policy = SACPolicy(config=config) + policy.train() + + assert len(policy.critic_ensemble.critics) == num_critics + + batch = create_train_batch_with_visual_input( + batch_size=batch_size, state_dim=state_dim, action_dim=action_dim + ) + + policy.train() + + optimizers = make_optimizers(policy) + + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + assert cirtic_loss.item() is not None + assert cirtic_loss.shape == () + cirtic_loss.backward() + optimizers["critic"].step() + + +def test_sac_policy_save_and_load(tmp_path): + root = tmp_path / "test_sac_save_and_load" + + state_dim = 10 + action_dim = 10 + batch_size = 2 + + config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim) + policy = SACPolicy(config=config) + policy.eval() + policy.save_pretrained(root) + loaded_policy = SACPolicy.from_pretrained(root, config=config) + loaded_policy.eval() + + batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10) + + with torch.no_grad(): + with seeded_context(12): + # Collect policy values before saving + cirtic_loss = policy.forward(batch, model="critic")["loss_critic"] + actor_loss = policy.forward(batch, model="actor")["loss_actor"] + temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"] + + observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + actions = policy.select_action(observation_batch) + + with seeded_context(12): + # Collect policy values after loading + loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"] + loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"] + loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"] + + loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim) + loaded_actions = loaded_policy.select_action(loaded_observation_batch) + + assert policy.state_dict().keys() == loaded_policy.state_dict().keys() + for k in policy.state_dict(): + assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6) + + # Compare values before and after saving and loading + # They should be the same + assert torch.allclose(cirtic_loss, loaded_cirtic_loss) + assert torch.allclose(actor_loss, loaded_actor_loss) + assert torch.allclose(temperature_loss, loaded_temperature_loss) + assert torch.allclose(actions, loaded_actions) diff --git a/tests/rl/test_actor.py b/tests/rl/test_actor.py new file mode 100644 index 000000000..0cf6a8f64 --- /dev/null +++ b/tests/rl/test_actor.py @@ -0,0 +1,208 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from concurrent import futures +from unittest.mock import patch + +import pytest +import torch +from torch.multiprocessing import Event, Queue + +from lerobot.common.utils.transition import Transition +from tests.utils import require_package + + +def create_learner_service_stub(): + import grpc + + from lerobot.common.transport import services_pb2, services_pb2_grpc + + class MockLearnerService(services_pb2_grpc.LearnerServiceServicer): + def __init__(self): + self.ready_call_count = 0 + self.should_fail = False + + def Ready(self, request, context): # noqa: N802 + self.ready_call_count += 1 + if self.should_fail: + context.set_code(grpc.StatusCode.UNAVAILABLE) + context.set_details("Service unavailable") + raise grpc.RpcError("Service unavailable") + return services_pb2.Empty() + + """Fixture to start a LearnerService gRPC server and provide a connected stub.""" + + servicer = MockLearnerService() + + # Create a gRPC server and add our servicer to it. + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS + server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} + + # Create a client channel and stub connected to the server's port. + channel = grpc.insecure_channel(f"localhost:{port}") + return services_pb2_grpc.LearnerServiceStub(channel), servicer, channel, server + + +def close_service_stub(channel, server): + channel.close() + server.stop(None) + + +@require_package("grpc") +def test_establish_learner_connection_success(): + from lerobot.scripts.rl.actor import establish_learner_connection + + """Test successful connection establishment.""" + stub, _servicer, channel, server = create_learner_service_stub() + + shutdown_event = Event() + + # Test successful connection + result = establish_learner_connection(stub, shutdown_event, attempts=5) + + assert result is True + + close_service_stub(channel, server) + + +@require_package("grpc") +def test_establish_learner_connection_failure(): + from lerobot.scripts.rl.actor import establish_learner_connection + + """Test connection failure.""" + stub, servicer, channel, server = create_learner_service_stub() + servicer.should_fail = True + + shutdown_event = Event() + + # Test failed connection + with patch("time.sleep"): # Speed up the test + result = establish_learner_connection(stub, shutdown_event, attempts=2) + + assert result is False + + close_service_stub(channel, server) + + +@require_package("grpc") +def test_push_transitions_to_transport_queue(): + from lerobot.common.transport.utils import bytes_to_transitions + from lerobot.scripts.rl.actor import push_transitions_to_transport_queue + from tests.transport.test_transport_utils import assert_transitions_equal + + """Test pushing transitions to transport queue.""" + # Create mock transitions + transitions = [] + for i in range(3): + transition = Transition( + state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.0 + i), + done=torch.tensor(False), + truncated=torch.tensor(False), + next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + complementary_info={"step": torch.tensor(i)}, + ) + transitions.append(transition) + + transitions_queue = Queue() + + # Test pushing transitions + push_transitions_to_transport_queue(transitions, transitions_queue) + + # Verify the data can be retrieved + serialized_data = transitions_queue.get() + assert isinstance(serialized_data, bytes) + deserialized_transitions = bytes_to_transitions(serialized_data) + assert len(deserialized_transitions) == len(transitions) + for i, deserialized_transition in enumerate(deserialized_transitions): + assert_transitions_equal(deserialized_transition, transitions[i]) + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_transitions_stream(): + from lerobot.scripts.rl.actor import transitions_stream + + """Test transitions stream functionality.""" + shutdown_event = Event() + transitions_queue = Queue() + + # Add test data to queue + test_data = [b"transition_data_1", b"transition_data_2", b"transition_data_3"] + for data in test_data: + transitions_queue.put(data) + + # Collect streamed data + streamed_data = [] + stream_generator = transitions_stream(shutdown_event, transitions_queue, 0.1) + + # Process a few items + for i, message in enumerate(stream_generator): + streamed_data.append(message) + if i >= len(test_data) - 1: + shutdown_event.set() + break + + # Verify we got messages + assert len(streamed_data) == len(test_data) + assert streamed_data[0].data == b"transition_data_1" + assert streamed_data[1].data == b"transition_data_2" + assert streamed_data[2].data == b"transition_data_3" + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_interactions_stream(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + from lerobot.scripts.rl.actor import interactions_stream + + """Test interactions stream functionality.""" + shutdown_event = Event() + interactions_queue = Queue() + + # Create test interaction data (similar structure to what would be sent) + test_interactions = [ + {"episode_reward": 10.5, "step": 1, "policy_fps": 30.2}, + {"episode_reward": 15.2, "step": 2, "policy_fps": 28.7}, + {"episode_reward": 8.7, "step": 3, "policy_fps": 29.1}, + ] + + # Serialize the interaction data as it would be in practice + test_data = [ + interactions_queue.put(python_object_to_bytes(interaction)) for interaction in test_interactions + ] + + # Collect streamed data + streamed_data = [] + stream_generator = interactions_stream(shutdown_event, interactions_queue, 0.1) + + # Process the items + for i, message in enumerate(stream_generator): + streamed_data.append(message) + if i >= len(test_data) - 1: + shutdown_event.set() + break + + # Verify we got messages + assert len(streamed_data) == len(test_data) + + # Verify the messages can be deserialized back to original data + for i, message in enumerate(streamed_data): + deserialized_interaction = bytes_to_python_object(message.data) + assert deserialized_interaction == test_interactions[i] diff --git a/tests/rl/test_actor_learner.py b/tests/rl/test_actor_learner.py new file mode 100644 index 000000000..cb72da7e4 --- /dev/null +++ b/tests/rl/test_actor_learner.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import socket +import threading +import time + +import pytest +import torch +from torch.multiprocessing import Event, Queue + +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.utils.transition import Transition +from lerobot.configs.train import TrainRLServerPipelineConfig +from tests.utils import require_package + + +def create_test_transitions(count: int = 3) -> list[Transition]: + """Create test transitions for integration testing.""" + transitions = [] + for i in range(count): + transition = Transition( + state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.0 + i), + done=torch.tensor(i == count - 1), # Last transition is done + truncated=torch.tensor(False), + next_state={"observation": torch.randn(3, 64, 64), "state": torch.randn(10)}, + complementary_info={"step": torch.tensor(i), "episode_id": i // 2}, + ) + transitions.append(transition) + return transitions + + +def create_test_interactions(count: int = 3) -> list[dict]: + """Create test interactions for integration testing.""" + interactions = [] + for i in range(count): + interaction = { + "episode_reward": 10.0 + i * 5, + "step": i * 100, + "policy_fps": 30.0 + i, + "intervention_rate": 0.1 * i, + "episode_length": 200 + i * 50, + } + interactions.append(interaction) + return interactions + + +def find_free_port(): + """Finds a free port on the local machine.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) # Bind to port 0 to let the OS choose a free port + s.listen(1) + port = s.getsockname()[1] + return port + + +@pytest.fixture +def cfg(): + cfg = TrainRLServerPipelineConfig() + + port = find_free_port() + + policy_cfg = SACConfig() + policy_cfg.actor_learner_config.learner_host = "127.0.0.1" + policy_cfg.actor_learner_config.learner_port = port + policy_cfg.concurrency.actor = "threads" + policy_cfg.concurrency.learner = "threads" + policy_cfg.actor_learner_config.queue_get_timeout = 0.1 + + cfg.policy = policy_cfg + + return cfg + + +@require_package("grpc") +@pytest.mark.timeout(10) # force cross-platform watchdog +def test_end_to_end_transitions_flow(cfg): + from lerobot.common.transport.utils import bytes_to_transitions + from lerobot.scripts.rl.actor import ( + establish_learner_connection, + learner_service_client, + push_transitions_to_transport_queue, + send_transitions, + ) + from lerobot.scripts.rl.learner import start_learner + from tests.transport.test_transport_utils import assert_transitions_equal + + """Test complete transitions flow from actor to learner.""" + transitions_actor_queue = Queue() + transitions_learner_queue = Queue() + + interactions_queue = Queue() + parameters_queue = Queue() + shutdown_event = Event() + + learner_thread = threading.Thread( + target=start_learner, + args=(parameters_queue, transitions_learner_queue, interactions_queue, shutdown_event, cfg), + ) + learner_thread.start() + + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + send_transitions_thread = threading.Thread( + target=send_transitions, args=(cfg, transitions_actor_queue, shutdown_event, learner_client, channel) + ) + send_transitions_thread.start() + + input_transitions = create_test_transitions(count=5) + + push_transitions_to_transport_queue(input_transitions, transitions_actor_queue) + + # Wait for learner to start + time.sleep(0.1) + + shutdown_event.set() + + # Wait for learner to receive transitions + learner_thread.join() + send_transitions_thread.join() + channel.close() + + received_transitions = [] + while not transitions_learner_queue.empty(): + received_transitions.extend(bytes_to_transitions(transitions_learner_queue.get())) + + assert len(received_transitions) == len(input_transitions) + for i, transition in enumerate(received_transitions): + assert_transitions_equal(transition, input_transitions[i]) + + +@require_package("grpc") +@pytest.mark.timeout(10) +def test_end_to_end_interactions_flow(cfg): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + from lerobot.scripts.rl.actor import ( + establish_learner_connection, + learner_service_client, + send_interactions, + ) + from lerobot.scripts.rl.learner import start_learner + + """Test complete interactions flow from actor to learner.""" + # Queues for actor-learner communication + interactions_actor_queue = Queue() + interactions_learner_queue = Queue() + + # Other queues required by the learner + parameters_queue = Queue() + transitions_learner_queue = Queue() + + shutdown_event = Event() + + # Start the learner in a separate thread + learner_thread = threading.Thread( + target=start_learner, + args=(parameters_queue, transitions_learner_queue, interactions_learner_queue, shutdown_event, cfg), + ) + learner_thread.start() + + # Establish connection from actor to learner + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + # Start the actor's interaction sending process in a separate thread + send_interactions_thread = threading.Thread( + target=send_interactions, + args=(cfg, interactions_actor_queue, shutdown_event, learner_client, channel), + ) + send_interactions_thread.start() + + # Create and push test interactions to the actor's queue + input_interactions = create_test_interactions(count=5) + for interaction in input_interactions: + interactions_actor_queue.put(python_object_to_bytes(interaction)) + + # Wait for the communication to happen + time.sleep(0.1) + + # Signal shutdown and wait for threads to complete + shutdown_event.set() + learner_thread.join() + send_interactions_thread.join() + channel.close() + + # Verify that the learner received the interactions + received_interactions = [] + while not interactions_learner_queue.empty(): + received_interactions.append(bytes_to_python_object(interactions_learner_queue.get())) + + assert len(received_interactions) == len(input_interactions) + + # Sort by a unique key to handle potential reordering in queues + received_interactions.sort(key=lambda x: x["step"]) + input_interactions.sort(key=lambda x: x["step"]) + + for received, expected in zip(received_interactions, input_interactions, strict=False): + assert received == expected + + +@require_package("grpc") +@pytest.mark.parametrize("data_size", ["small", "large"]) +@pytest.mark.timeout(10) +def test_end_to_end_parameters_flow(cfg, data_size): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + from lerobot.scripts.rl.actor import establish_learner_connection, learner_service_client, receive_policy + from lerobot.scripts.rl.learner import start_learner + + """Test complete parameter flow from learner to actor, with small and large data.""" + # Actor's local queue to receive params + parameters_actor_queue = Queue() + # Learner's queue to send params from + parameters_learner_queue = Queue() + + # Other queues required by the learner + transitions_learner_queue = Queue() + interactions_learner_queue = Queue() + + shutdown_event = Event() + + # Start the learner in a separate thread + learner_thread = threading.Thread( + target=start_learner, + args=( + parameters_learner_queue, + transitions_learner_queue, + interactions_learner_queue, + shutdown_event, + cfg, + ), + ) + learner_thread.start() + + # Establish connection from actor to learner + policy_cfg = cfg.policy + learner_client, channel = learner_service_client( + host=policy_cfg.actor_learner_config.learner_host, port=policy_cfg.actor_learner_config.learner_port + ) + + assert establish_learner_connection(learner_client, shutdown_event, attempts=5) + + # Start the actor's parameter receiving process in a separate thread + receive_params_thread = threading.Thread( + target=receive_policy, + args=(cfg, parameters_actor_queue, shutdown_event, learner_client, channel), + ) + receive_params_thread.start() + + # Create test parameters based on parametrization + if data_size == "small": + input_params = {"layer.weight": torch.randn(128, 64)} + else: # "large" + # CHUNK_SIZE is 2MB, so this tensor (4MB) will force chunking + input_params = {"large_layer.weight": torch.randn(1024, 1024)} + + # Simulate learner having new parameters to send + parameters_learner_queue.put(state_to_bytes(input_params)) + + # Wait for the actor to receive the parameters + time.sleep(0.1) + + # Signal shutdown and wait for threads to complete + shutdown_event.set() + learner_thread.join() + receive_params_thread.join() + channel.close() + + # Verify that the actor received the parameters correctly + received_params = bytes_to_state_dict(parameters_actor_queue.get()) + + assert received_params.keys() == input_params.keys() + for key in input_params: + assert torch.allclose(received_params[key], input_params[key]) diff --git a/tests/rl/test_learner_service.py b/tests/rl/test_learner_service.py new file mode 100644 index 000000000..ee9d06e91 --- /dev/null +++ b/tests/rl/test_learner_service.py @@ -0,0 +1,374 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import threading +import time +from concurrent import futures +from multiprocessing import Event, Queue + +import pytest + +from tests.utils import require_package # our gRPC servicer class + + +@pytest.fixture(scope="function") +def learner_service_stub(): + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + yield client # provide the stub to the test function + + close_learner_service_stub(channel, server) + + +@require_package("grpc") +def create_learner_service_stub( + shutdown_event: Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, + seconds_between_pushes: int, + queue_get_timeout: float = 0.1, +): + import grpc + + from lerobot.common.transport import services_pb2_grpc # generated from .proto + from lerobot.scripts.rl.learner_service import LearnerService + + """Fixture to start a LearnerService gRPC server and provide a connected stub.""" + + servicer = LearnerService( + shutdown_event=shutdown_event, + parameters_queue=parameters_queue, + seconds_between_pushes=seconds_between_pushes, + transition_queue=transitions_queue, + interaction_message_queue=interactions_queue, + queue_get_timeout=queue_get_timeout, + ) + + # Create a gRPC server and add our servicer to it. + server = grpc.server(futures.ThreadPoolExecutor(max_workers=4)) + services_pb2_grpc.add_LearnerServiceServicer_to_server(servicer, server) + port = server.add_insecure_port("[::]:0") # bind to a free port chosen by OS + server.start() # start the server (non-blocking call):contentReference[oaicite:1]{index=1} + + # Create a client channel and stub connected to the server's port. + channel = grpc.insecure_channel(f"localhost:{port}") + return services_pb2_grpc.LearnerServiceStub(channel), channel, server + + +@require_package("grpc") +def close_learner_service_stub(channel, server): + channel.close() + server.stop(None) + + +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_ready_method(learner_service_stub): + from lerobot.common.transport import services_pb2 + + """Test the ready method of the UserService.""" + request = services_pb2.Empty() + response = learner_service_stub.Ready(request) + assert response == services_pb2.Empty() + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_interactions(): + from lerobot.common.transport import services_pb2 + + shutdown_event = Event() + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + list_of_interaction_messages = [ + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"1"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"2"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"3"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"4"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"5"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"6"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"7"), + services_pb2.InteractionMessage(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"8"), + ] + + def mock_intercations_stream(): + yield from list_of_interaction_messages + + return services_pb2.Empty() + + response = client.SendInteractions(mock_intercations_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Extract the data from the interactions queue + interactions = [] + while not interactions_queue.empty(): + interactions.append(interactions_queue.get()) + + assert interactions == [b"123", b"4", b"5", b"678"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_transitions(): + from lerobot.common.transport import services_pb2 + + """Test the SendTransitions method with various transition data.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + # Create test transition messages + list_of_transition_messages = [ + services_pb2.Transition( + transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"transition_1" + ), + services_pb2.Transition( + transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE, data=b"transition_2" + ), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"transition_3"), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_BEGIN, data=b"batch_1"), + services_pb2.Transition(transfer_state=services_pb2.TransferState.TRANSFER_END, data=b"batch_2"), + ] + + def mock_transitions_stream(): + yield from list_of_transition_messages + + response = client.SendTransitions(mock_transitions_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Extract the data from the transitions queue + transitions = [] + while not transitions_queue.empty(): + transitions.append(transitions_queue.get()) + + # Should have assembled the chunked data + assert transitions == [b"transition_1transition_2transition_3", b"batch_1batch_2"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_send_transitions_empty_stream(): + from lerobot.common.transport import services_pb2 + + """Test SendTransitions with empty stream.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 1 + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + def empty_stream(): + return iter([]) + + response = client.SendTransitions(empty_stream()) + assert response == services_pb2.Empty() + + close_learner_service_stub(channel, server) + + # Queue should remain empty + assert transitions_queue.empty() + + +@require_package("grpc") +@pytest.mark.timeout(10) # force cross-platform watchdog +def test_stream_parameters(): + import time + + from lerobot.common.transport import services_pb2 + + """Test the StreamParameters method.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.2 # Short delay for testing + + client, channel, server = create_learner_service_stub( + shutdown_event, parameters_queue, transitions_queue, interactions_queue, seconds_between_pushes + ) + + # Add test parameters to the queue + test_params = [b"param_batch_1", b"param_batch_2"] + for param in test_params: + parameters_queue.put(param) + + # Start streaming parameters + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + # Collect streamed parameters and timestamps + received_params = [] + timestamps = [] + + for response in stream: + received_params.append(response.data) + timestamps.append(time.time()) + + # We should receive one last item + break + + parameters_queue.put(b"param_batch_3") + + for response in stream: + received_params.append(response.data) + timestamps.append(time.time()) + + # We should receive only one item + break + + shutdown_event.set() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_batch_2", b"param_batch_3"] + + # Check the time difference between the two sends + time_diff = timestamps[1] - timestamps[0] + # Check if the time difference is close to the expected push frequency + assert time_diff == pytest.approx(seconds_between_pushes, abs=0.1) + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_stream_parameters_with_shutdown(): + from lerobot.common.transport import services_pb2 + + """Test StreamParameters handles shutdown gracefully.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.1 + queue_get_timeout = 0.001 + + client, channel, server = create_learner_service_stub( + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, + seconds_between_pushes, + queue_get_timeout=queue_get_timeout, + ) + + test_params = [b"param_batch_1", b"stop", b"param_batch_3", b"param_batch_4"] + + # create a thread that will put the parameters in the queue + def producer(): + for param in test_params: + parameters_queue.put(param) + time.sleep(0.1) + + producer_thread = threading.Thread(target=producer) + producer_thread.start() + + # Start streaming + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + # Collect streamed parameters + received_params = [] + + for response in stream: + received_params.append(response.data) + + if response.data == b"stop": + shutdown_event.set() + + producer_thread.join() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_batch_1", b"stop"] + + +@require_package("grpc") +@pytest.mark.timeout(3) # force cross-platform watchdog +def test_stream_parameters_waits_and_retries_on_empty_queue(): + import threading + import time + + from lerobot.common.transport import services_pb2 + + """Test that StreamParameters waits and retries when the queue is empty.""" + shutdown_event = Event() + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + seconds_between_pushes = 0.05 + queue_get_timeout = 0.01 + + client, channel, server = create_learner_service_stub( + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, + seconds_between_pushes, + queue_get_timeout=queue_get_timeout, + ) + + request = services_pb2.Empty() + stream = client.StreamParameters(request) + + received_params = [] + + def producer(): + # Let the consumer start and find an empty queue. + # It will wait `seconds_between_pushes` (0.05s), then `get` will timeout after `queue_get_timeout` (0.01s). + # Total time for the first empty loop is > 0.06s. We wait a bit longer to be safe. + time.sleep(0.06) + parameters_queue.put(b"param_after_wait") + time.sleep(0.05) + parameters_queue.put(b"param_after_wait_2") + + producer_thread = threading.Thread(target=producer) + producer_thread.start() + + # The consumer will block here until the producer sends an item. + for response in stream: + received_params.append(response.data) + if response.data == b"param_after_wait_2": + break # We only need one item for this test. + + shutdown_event.set() + producer_thread.join() + close_learner_service_stub(channel, server) + + assert received_params == [b"param_after_wait", b"param_after_wait_2"] diff --git a/tests/robots/test_robots.py b/tests/robots/test_robots.py deleted file mode 100644 index 71343eba9..000000000 --- a/tests/robots/test_robots.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2024 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Tests for physical robots and their mocked versions. -If the physical robots are not connected to the computer, or not working, -the test will be skipped. - -Example of running a specific test: -```bash -pytest -sx tests/test_robots.py::test_robot -``` - -Example of running test on real robots connected to the computer: -```bash -pytest -sx 'tests/test_robots.py::test_robot[koch-False]' -pytest -sx 'tests/test_robots.py::test_robot[koch_bimanual-False]' -pytest -sx 'tests/test_robots.py::test_robot[aloha-False]' -``` - -Example of running test on a mocked version of robots: -```bash -pytest -sx 'tests/test_robots.py::test_robot[koch-True]' -pytest -sx 'tests/test_robots.py::test_robot[koch_bimanual-True]' -pytest -sx 'tests/test_robots.py::test_robot[aloha-True]' -``` -""" - -import pytest -import torch - -from lerobot.common.robot_devices.robots.utils import make_robot -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError -from tests.utils import TEST_ROBOT_TYPES, mock_calibration_dir, require_robot - - -@pytest.mark.parametrize("robot_type, mock", TEST_ROBOT_TYPES) -@require_robot -def test_robot(tmp_path, request, robot_type, mock): - # TODO(rcadene): measure fps in nightly? - # TODO(rcadene): test logs - # TODO(rcadene): add compatibility with other robots - robot_kwargs = {"robot_type": robot_type, "mock": mock} - - if robot_type == "aloha" and mock: - # To simplify unit test, we do not rerun manual calibration for Aloha mock=True. - # Instead, we use the files from '.cache/calibration/aloha_default' - pass - else: - if mock: - request.getfixturevalue("patch_builtins_input") - - # Create an empty calibration directory to trigger manual calibration - calibration_dir = tmp_path / robot_type - mock_calibration_dir(calibration_dir) - robot_kwargs["calibration_dir"] = calibration_dir - - # Test using robot before connecting raises an error - robot = make_robot(**robot_kwargs) - with pytest.raises(RobotDeviceNotConnectedError): - robot.teleop_step() - with pytest.raises(RobotDeviceNotConnectedError): - robot.teleop_step(record_data=True) - with pytest.raises(RobotDeviceNotConnectedError): - robot.capture_observation() - with pytest.raises(RobotDeviceNotConnectedError): - robot.send_action(None) - with pytest.raises(RobotDeviceNotConnectedError): - robot.disconnect() - - # Test deleting the object without connecting first - del robot - - # Test connecting (triggers manual calibration) - robot = make_robot(**robot_kwargs) - robot.connect() - assert robot.is_connected - - # Test connecting twice raises an error - with pytest.raises(RobotDeviceAlreadyConnectedError): - robot.connect() - - # TODO(rcadene, aliberts): Test disconnecting with `__del__` instead of `disconnect` - # del robot - robot.disconnect() - - # Test teleop can run - robot = make_robot(**robot_kwargs) - robot.connect() - robot.teleop_step() - - # Test data recorded during teleop are well formatted - observation, action = robot.teleop_step(record_data=True) - # State - assert "observation.state" in observation - assert isinstance(observation["observation.state"], torch.Tensor) - assert observation["observation.state"].ndim == 1 - dim_state = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms) - assert observation["observation.state"].shape[0] == dim_state - # Cameras - for name in robot.cameras: - assert f"observation.images.{name}" in observation - assert isinstance(observation[f"observation.images.{name}"], torch.Tensor) - assert observation[f"observation.images.{name}"].ndim == 3 - # Action - assert "action" in action - assert isinstance(action["action"], torch.Tensor) - assert action["action"].ndim == 1 - dim_action = sum(len(robot.follower_arms[name].motors) for name in robot.follower_arms) - assert action["action"].shape[0] == dim_action - # TODO(rcadene): test if observation and action data are returned as expected - - # Test capture_observation can run and observation returned are the same (since the arm didnt move) - captured_observation = robot.capture_observation() - assert set(captured_observation.keys()) == set(observation.keys()) - for name in captured_observation: - if "image" in name: - # TODO(rcadene): skipping image for now as it's challenging to assess equality between two consecutive frames - continue - torch.testing.assert_close(captured_observation[name], observation[name], rtol=1e-4, atol=1) - assert captured_observation[name].shape == observation[name].shape - - # Test send_action can run - robot.send_action(action["action"]) - - # Test disconnecting - robot.disconnect() - assert not robot.is_connected - for name in robot.follower_arms: - assert not robot.follower_arms[name].is_connected - for name in robot.leader_arms: - assert not robot.leader_arms[name].is_connected - for name in robot.cameras: - assert not robot.cameras[name].is_connected diff --git a/tests/robots/test_so100_follower.py b/tests/robots/test_so100_follower.py new file mode 100644 index 000000000..81d9d6a91 --- /dev/null +++ b/tests/robots/test_so100_follower.py @@ -0,0 +1,95 @@ +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest + +from lerobot.common.robots.so100_follower import ( + SO100Follower, + SO100FollowerConfig, +) + + +def _make_bus_mock() -> MagicMock: + """Return a bus mock with just the attributes used by the robot.""" + bus = MagicMock(name="FeetechBusMock") + bus.is_connected = False + + def _connect(): + bus.is_connected = True + + def _disconnect(_disable=True): + bus.is_connected = False + + bus.connect.side_effect = _connect + bus.disconnect.side_effect = _disconnect + + @contextmanager + def _dummy_cm(): + yield + + bus.torque_disabled.side_effect = _dummy_cm + + return bus + + +@pytest.fixture +def follower(): + bus_mock = _make_bus_mock() + + def _bus_side_effect(*_args, **kwargs): + bus_mock.motors = kwargs["motors"] + motors_order: list[str] = list(bus_mock.motors) + + bus_mock.sync_read.return_value = {motor: idx for idx, motor in enumerate(motors_order, 1)} + bus_mock.sync_write.return_value = None + bus_mock.write.return_value = None + bus_mock.disable_torque.return_value = None + bus_mock.enable_torque.return_value = None + bus_mock.is_calibrated = True + return bus_mock + + with ( + patch( + "lerobot.common.robots.so100_follower.so100_follower.FeetechMotorsBus", + side_effect=_bus_side_effect, + ), + patch.object(SO100Follower, "configure", lambda self: None), + ): + cfg = SO100FollowerConfig(port="/dev/null") + robot = SO100Follower(cfg) + yield robot + if robot.is_connected: + robot.disconnect() + + +def test_connect_disconnect(follower): + assert not follower.is_connected + + follower.connect() + assert follower.is_connected + + follower.disconnect() + assert not follower.is_connected + + +def test_get_observation(follower): + follower.connect() + obs = follower.get_observation() + + expected_keys = {f"{m}.pos" for m in follower.bus.motors} + assert set(obs.keys()) == expected_keys + + for idx, motor in enumerate(follower.bus.motors, 1): + assert obs[f"{motor}.pos"] == idx + + +def test_send_action(follower): + follower.connect() + + action = {f"{m}.pos": i * 10 for i, m in enumerate(follower.bus.motors, 1)} + returned = follower.send_action(action) + + assert returned == action + + goal_pos = {m: (i + 1) * 10 for i, m in enumerate(follower.bus.motors)} + follower.bus.sync_write.assert_called_once_with("Goal_Position", goal_pos) diff --git a/tests/test_available.py b/tests/test_available.py index f4f9d4de6..a18b95ffa 100644 --- a/tests/test_available.py +++ b/tests/test_available.py @@ -45,12 +45,7 @@ def test_available_policies(): This test verifies that the class attribute `name` for all policies is consistent with those listed in `lerobot/__init__.py`. """ - policy_classes = [ - ACTPolicy, - DiffusionPolicy, - TDMPCPolicy, - VQBeTPolicy, - ] + policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy] policies = [pol_cls.name for pol_cls in policy_classes] assert set(policies) == set(lerobot.available_policies), policies diff --git a/tests/transport/test_transport_utils.py b/tests/transport/test_transport_utils.py new file mode 100644 index 000000000..cf33f52c0 --- /dev/null +++ b/tests/transport/test_transport_utils.py @@ -0,0 +1,571 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +from multiprocessing import Event, Queue +from pickle import UnpicklingError + +import pytest +import torch + +from lerobot.common.utils.transition import Transition +from tests.utils import require_cuda, require_package + + +@require_package("grpc") +def test_bytes_buffer_size_empty_buffer(): + from lerobot.common.transport.utils import bytes_buffer_size + + """Test with an empty buffer.""" + buffer = io.BytesIO() + assert bytes_buffer_size(buffer) == 0 + # Ensure position is reset to beginning + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_bytes_buffer_size_small_buffer(): + from lerobot.common.transport.utils import bytes_buffer_size + + """Test with a small buffer.""" + buffer = io.BytesIO(b"Hello, World!") + assert bytes_buffer_size(buffer) == 13 + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_bytes_buffer_size_large_buffer(): + from lerobot.common.transport.utils import CHUNK_SIZE, bytes_buffer_size + + """Test with a large buffer.""" + data = b"x" * (CHUNK_SIZE * 2 + 1000) + buffer = io.BytesIO(data) + assert bytes_buffer_size(buffer) == len(data) + assert buffer.tell() == 0 + + +@require_package("grpc") +def test_send_bytes_in_chunks_empty_data(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test sending empty data.""" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(b"", message_class)) + assert len(chunks) == 0 + + +@require_package("grpc") +def test_single_chunk_small_data(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test data that fits in a single chunk.""" + data = b"Some data" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + + assert len(chunks) == 1 + assert chunks[0].data == b"Some data" + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_not_silent_mode(): + from lerobot.common.transport.utils import send_bytes_in_chunks, services_pb2 + + """Test not silent mode.""" + data = b"Some data" + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class, silent=False)) + assert len(chunks) == 1 + assert chunks[0].data == b"Some data" + + +@require_package("grpc") +def test_send_bytes_in_chunks_large_data(): + from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 + + """Test sending large data.""" + data = b"x" * (CHUNK_SIZE * 2 + 1000) + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + assert len(chunks) == 3 + assert chunks[0].data == b"x" * CHUNK_SIZE + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_BEGIN + assert chunks[1].data == b"x" * CHUNK_SIZE + assert chunks[1].transfer_state == services_pb2.TransferState.TRANSFER_MIDDLE + assert chunks[2].data == b"x" * 1000 + assert chunks[2].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_send_bytes_in_chunks_large_data_with_exact_chunk_size(): + from lerobot.common.transport.utils import CHUNK_SIZE, send_bytes_in_chunks, services_pb2 + + """Test sending large data with exact chunk size.""" + data = b"x" * CHUNK_SIZE + message_class = services_pb2.InteractionMessage + chunks = list(send_bytes_in_chunks(data, message_class)) + assert len(chunks) == 1 + assert chunks[0].data == data + assert chunks[0].transfer_state == services_pb2.TransferState.TRANSFER_END + + +@require_package("grpc") +def test_receive_bytes_in_chunks_empty_data(): + from lerobot.common.transport.utils import receive_bytes_in_chunks + + """Test receiving empty data.""" + queue = Queue() + shutdown_event = Event() + + # Empty iterator + receive_bytes_in_chunks(iter([]), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_single_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a single chunk message.""" + queue = Queue() + shutdown_event = Event() + + data = b"Single chunk data" + chunks = [ + services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_END) + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.get(timeout=0.01) == data + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_single_not_end_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a single chunk message.""" + queue = Queue() + shutdown_event = Event() + + data = b"Single chunk data" + chunks = [ + services_pb2.InteractionMessage(data=data, transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE) + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_multiple_chunks(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving a multi-chunk message.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + services_pb2.InteractionMessage( + data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.get(timeout=0.01) == b"First Middle Last" + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_multiple_messages(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving multiple complete messages in sequence.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + # First message - single chunk + services_pb2.InteractionMessage( + data=b"Message1", transfer_state=services_pb2.TransferState.TRANSFER_END + ), + # Second message - multi chunk + services_pb2.InteractionMessage( + data=b"Start2 ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle2 ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"End2", transfer_state=services_pb2.TransferState.TRANSFER_END), + # Third message - single chunk + services_pb2.InteractionMessage( + data=b"Message3", transfer_state=services_pb2.TransferState.TRANSFER_END + ), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + # Should have three messages in queue + assert queue.get(timeout=0.01) == b"Message1" + assert queue.get(timeout=0.01) == b"Start2 Middle2 End2" + assert queue.get(timeout=0.01) == b"Message3" + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_shutdown_during_receive(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test that shutdown event stops receiving mid-stream.""" + queue = Queue() + shutdown_event = Event() + shutdown_event.set() + + chunks = [ + services_pb2.InteractionMessage( + data=b"First ", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + services_pb2.InteractionMessage( + data=b"Middle ", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"Last", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_only_begin_chunk(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving only a BEGIN chunk without END.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + services_pb2.InteractionMessage( + data=b"Start", transfer_state=services_pb2.TransferState.TRANSFER_BEGIN + ), + # No END chunk + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + assert queue.empty() + + +@require_package("grpc") +def test_receive_bytes_in_chunks_missing_begin(): + from lerobot.common.transport.utils import receive_bytes_in_chunks, services_pb2 + + """Test receiving chunks starting with MIDDLE instead of BEGIN.""" + queue = Queue() + shutdown_event = Event() + + chunks = [ + # Missing BEGIN + services_pb2.InteractionMessage( + data=b"Middle", transfer_state=services_pb2.TransferState.TRANSFER_MIDDLE + ), + services_pb2.InteractionMessage(data=b"End", transfer_state=services_pb2.TransferState.TRANSFER_END), + ] + + receive_bytes_in_chunks(iter(chunks), queue, shutdown_event) + + # The implementation continues from where it is, so we should get partial data + assert queue.get(timeout=0.01) == b"MiddleEnd" + assert queue.empty() + + +# Tests for state_to_bytes and bytes_to_state_dict +@require_package("grpc") +def test_state_to_bytes_empty_dict(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting empty state dict to bytes.""" + state_dict = {} + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + assert reconstructed == state_dict + + +@require_package("grpc") +def test_bytes_to_state_dict_empty_data(): + from lerobot.common.transport.utils import bytes_to_state_dict + + """Test converting empty data to state dict.""" + with pytest.raises(EOFError): + bytes_to_state_dict(b"") + + +@require_package("grpc") +def test_state_to_bytes_simple_dict(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting simple state dict to bytes.""" + state_dict = { + "layer1.weight": torch.randn(10, 5), + "layer1.bias": torch.randn(10), + "layer2.weight": torch.randn(1, 10), + "layer2.bias": torch.randn(1), + } + + data = state_to_bytes(state_dict) + assert isinstance(data, bytes) + assert len(data) > 0 + + reconstructed = bytes_to_state_dict(data) + + assert len(reconstructed) == len(state_dict) + for key in state_dict: + assert key in reconstructed + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_state_to_bytes_various_dtypes(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting state dict with various tensor dtypes.""" + state_dict = { + "float32": torch.randn(5, 5), + "float64": torch.randn(3, 3).double(), + "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32), + "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64), + "bool": torch.tensor([True, False, True]), + "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), + } + + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + + for key in state_dict: + assert reconstructed[key].dtype == state_dict[key].dtype + if state_dict[key].dtype == torch.bool: + assert torch.equal(state_dict[key], reconstructed[key]) + else: + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_bytes_to_state_dict_invalid_data(): + from lerobot.common.transport.utils import bytes_to_state_dict + + """Test bytes_to_state_dict with invalid data.""" + with pytest.raises(UnpicklingError): + bytes_to_state_dict(b"This is not a valid torch save file") + + +@require_cuda +@require_package("grpc") +def test_state_to_bytes_various_dtypes_cuda(): + from lerobot.common.transport.utils import bytes_to_state_dict, state_to_bytes + + """Test converting state dict with various tensor dtypes.""" + state_dict = { + "float32": torch.randn(5, 5).cuda(), + "float64": torch.randn(3, 3).double().cuda(), + "int32": torch.randint(0, 100, (4, 4), dtype=torch.int32).cuda(), + "int64": torch.randint(0, 100, (2, 2), dtype=torch.int64).cuda(), + "bool": torch.tensor([True, False, True]), + "uint8": torch.randint(0, 255, (3, 3), dtype=torch.uint8), + } + + data = state_to_bytes(state_dict) + reconstructed = bytes_to_state_dict(data) + + for key in state_dict: + assert reconstructed[key].dtype == state_dict[key].dtype + if state_dict[key].dtype == torch.bool: + assert torch.equal(state_dict[key], reconstructed[key]) + else: + assert torch.allclose(state_dict[key], reconstructed[key]) + + +@require_package("grpc") +def test_python_object_to_bytes_none(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting None to bytes.""" + obj = None + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + assert reconstructed is None + + +@pytest.mark.parametrize( + "obj", + [ + 42, + -123, + 3.14159, + -2.71828, + "Hello, World!", + "Unicode: 你好世界 🌍", + True, + False, + b"byte string", + [], + [1, 2, 3], + [1, "two", 3.0, True, None], + {}, + {"key": "value", "number": 123, "nested": {"a": 1}}, + (), + (1, 2, 3), + ], +) +@require_package("grpc") +def test_python_object_to_bytes_simple_types(obj): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting simple Python types.""" + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + assert reconstructed == obj + assert type(reconstructed) is type(obj) + + +@require_package("grpc") +def test_python_object_to_bytes_with_tensors(): + from lerobot.common.transport.utils import bytes_to_python_object, python_object_to_bytes + + """Test converting objects containing PyTorch tensors.""" + obj = { + "tensor": torch.randn(5, 5), + "list_with_tensor": [1, 2, torch.randn(3, 3), "string"], + "nested": { + "tensor1": torch.randn(2, 2), + "tensor2": torch.tensor([1, 2, 3]), + }, + } + + data = python_object_to_bytes(obj) + reconstructed = bytes_to_python_object(data) + + assert torch.allclose(obj["tensor"], reconstructed["tensor"]) + assert reconstructed["list_with_tensor"][0] == 1 + assert reconstructed["list_with_tensor"][3] == "string" + assert torch.allclose(obj["list_with_tensor"][2], reconstructed["list_with_tensor"][2]) + assert torch.allclose(obj["nested"]["tensor1"], reconstructed["nested"]["tensor1"]) + assert torch.equal(obj["nested"]["tensor2"], reconstructed["nested"]["tensor2"]) + + +@require_package("grpc") +def test_transitions_to_bytes_empty_list(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting empty transitions list.""" + transitions = [] + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + assert reconstructed == transitions + assert isinstance(reconstructed, list) + + +@require_package("grpc") +def test_transitions_to_bytes_single_transition(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting a single transition.""" + transition = Transition( + state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, + action=torch.randn(5), + reward=torch.tensor(1.5), + done=torch.tensor(False), + next_state={"image": torch.randn(3, 64, 64), "state": torch.randn(10)}, + ) + + transitions = [transition] + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + + assert len(reconstructed) == 1 + + assert_transitions_equal(transitions[0], reconstructed[0]) + + +@require_package("grpc") +def assert_transitions_equal(t1: Transition, t2: Transition): + """Helper to assert two transitions are equal.""" + assert_observation_equal(t1["state"], t2["state"]) + assert torch.allclose(t1["action"], t2["action"]) + assert torch.allclose(t1["reward"], t2["reward"]) + assert torch.equal(t1["done"], t2["done"]) + assert_observation_equal(t1["next_state"], t2["next_state"]) + + +@require_package("grpc") +def assert_observation_equal(o1: dict, o2: dict): + """Helper to assert two observations are equal.""" + assert set(o1.keys()) == set(o2.keys()) + for key in o1: + assert torch.allclose(o1[key], o2[key]) + + +@require_package("grpc") +def test_transitions_to_bytes_multiple_transitions(): + from lerobot.common.transport.utils import bytes_to_transitions, transitions_to_bytes + + """Test converting multiple transitions.""" + transitions = [] + for i in range(5): + transition = Transition( + state={"data": torch.randn(10)}, + action=torch.randn(3), + reward=torch.tensor(float(i)), + done=torch.tensor(i == 4), + next_state={"data": torch.randn(10)}, + ) + transitions.append(transition) + + data = transitions_to_bytes(transitions) + reconstructed = bytes_to_transitions(data) + + assert len(reconstructed) == len(transitions) + for original, reconstructed_item in zip(transitions, reconstructed, strict=False): + assert_transitions_equal(original, reconstructed_item) + + +@require_package("grpc") +def test_receive_bytes_in_chunks_unknown_state(): + from lerobot.common.transport.utils import receive_bytes_in_chunks + + """Test receive_bytes_in_chunks with an unknown transfer state.""" + + # Mock the gRPC message object, which has `transfer_state` and `data` attributes. + class MockMessage: + def __init__(self, transfer_state, data): + self.transfer_state = transfer_state + self.data = data + + # 10 is not a valid TransferState enum value + bad_iterator = [MockMessage(transfer_state=10, data=b"bad_data")] + output_queue = Queue() + shutdown_event = Event() + + with pytest.raises(ValueError, match="Received unknown transfer state"): + receive_bytes_in_chunks(bad_iterator, output_queue, shutdown_event) diff --git a/tests/utils.py b/tests/utils.py index c49b5b9ff..103b973fb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -13,20 +13,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import os import platform from functools import wraps -from pathlib import Path import pytest import torch from lerobot import available_cameras, available_motors, available_robots -from lerobot.common.robot_devices.cameras.utils import Camera -from lerobot.common.robot_devices.cameras.utils import make_camera as make_camera_device -from lerobot.common.robot_devices.motors.utils import MotorsBus -from lerobot.common.robot_devices.motors.utils import make_motors_bus as make_motors_bus_device from lerobot.common.utils.import_utils import is_package_available DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", "cuda") if torch.cuda.is_available() else "cpu" @@ -188,144 +182,3 @@ def require_package(package_name): return wrapper return decorator - - -def require_robot(func): - """ - Decorator that skips the test if a robot is not available - - The decorated function must have two arguments `request` and `robot_type`. - - Example of usage: - ```python - @pytest.mark.parametrize( - "robot_type", ["koch", "aloha"] - ) - @require_robot - def test_require_robot(request, robot_type): - pass - ``` - """ - - @wraps(func) - def wrapper(*args, **kwargs): - # Access the pytest request context to get the is_robot_available fixture - request = kwargs.get("request") - robot_type = kwargs.get("robot_type") - mock = kwargs.get("mock") - - if robot_type is None: - raise ValueError("The 'robot_type' must be an argument of the test function.") - if request is None: - raise ValueError("The 'request' fixture must be an argument of the test function.") - if mock is None: - raise ValueError("The 'mock' variable must be an argument of the test function.") - - # Run test with a real robot. Skip test if robot connection fails. - if not mock and not request.getfixturevalue("is_robot_available"): - pytest.skip(f"A {robot_type} robot is not available.") - - return func(*args, **kwargs) - - return wrapper - - -def require_camera(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Access the pytest request context to get the is_camera_available fixture - request = kwargs.get("request") - camera_type = kwargs.get("camera_type") - mock = kwargs.get("mock") - - if request is None: - raise ValueError("The 'request' fixture must be an argument of the test function.") - if camera_type is None: - raise ValueError("The 'camera_type' must be an argument of the test function.") - if mock is None: - raise ValueError("The 'mock' variable must be an argument of the test function.") - - if not mock and not request.getfixturevalue("is_camera_available"): - pytest.skip(f"A {camera_type} camera is not available.") - - return func(*args, **kwargs) - - return wrapper - - -def require_motor(func): - @wraps(func) - def wrapper(*args, **kwargs): - # Access the pytest request context to get the is_motor_available fixture - request = kwargs.get("request") - motor_type = kwargs.get("motor_type") - mock = kwargs.get("mock") - - if request is None: - raise ValueError("The 'request' fixture must be an argument of the test function.") - if motor_type is None: - raise ValueError("The 'motor_type' must be an argument of the test function.") - if mock is None: - raise ValueError("The 'mock' variable must be an argument of the test function.") - - if not mock and not request.getfixturevalue("is_motor_available"): - pytest.skip(f"A {motor_type} motor is not available.") - - return func(*args, **kwargs) - - return wrapper - - -def mock_calibration_dir(calibration_dir): - # TODO(rcadene): remove this hack - # calibration file produced with Moss v1, but works with Koch, Koch bimanual and SO-100 - example_calib = { - "homing_offset": [-1416, -845, 2130, 2872, 1950, -2211], - "drive_mode": [0, 0, 1, 1, 1, 0], - "start_pos": [1442, 843, 2166, 2849, 1988, 1835], - "end_pos": [2440, 1869, -1106, -1848, -926, 3235], - "calib_mode": ["DEGREE", "DEGREE", "DEGREE", "DEGREE", "DEGREE", "LINEAR"], - "motor_names": ["shoulder_pan", "shoulder_lift", "elbow_flex", "wrist_flex", "wrist_roll", "gripper"], - } - Path(str(calibration_dir)).mkdir(parents=True, exist_ok=True) - with open(calibration_dir / "main_follower.json", "w") as f: - json.dump(example_calib, f) - with open(calibration_dir / "main_leader.json", "w") as f: - json.dump(example_calib, f) - with open(calibration_dir / "left_follower.json", "w") as f: - json.dump(example_calib, f) - with open(calibration_dir / "left_leader.json", "w") as f: - json.dump(example_calib, f) - with open(calibration_dir / "right_follower.json", "w") as f: - json.dump(example_calib, f) - with open(calibration_dir / "right_leader.json", "w") as f: - json.dump(example_calib, f) - - -# TODO(rcadene, aliberts): remove this dark pattern that overrides -def make_camera(camera_type: str, **kwargs) -> Camera: - if camera_type == "opencv": - camera_index = kwargs.pop("camera_index", OPENCV_CAMERA_INDEX) - return make_camera_device(camera_type, camera_index=camera_index, **kwargs) - - elif camera_type == "intelrealsense": - serial_number = kwargs.pop("serial_number", INTELREALSENSE_SERIAL_NUMBER) - return make_camera_device(camera_type, serial_number=serial_number, **kwargs) - else: - raise ValueError(f"The camera type '{camera_type}' is not valid.") - - -# TODO(rcadene, aliberts): remove this dark pattern that overrides -def make_motors_bus(motor_type: str, **kwargs) -> MotorsBus: - if motor_type == "dynamixel": - port = kwargs.pop("port", DYNAMIXEL_PORT) - motors = kwargs.pop("motors", DYNAMIXEL_MOTORS) - return make_motors_bus_device(motor_type, port=port, motors=motors, **kwargs) - - elif motor_type == "feetech": - port = kwargs.pop("port", FEETECH_PORT) - motors = kwargs.pop("motors", FEETECH_MOTORS) - return make_motors_bus_device(motor_type, port=port, motors=motors, **kwargs) - - else: - raise ValueError(f"The motor type '{motor_type}' is not valid.") diff --git a/tests/utils/test_encoding_utils.py b/tests/utils/test_encoding_utils.py new file mode 100644 index 000000000..c21c8081e --- /dev/null +++ b/tests/utils/test_encoding_utils.py @@ -0,0 +1,155 @@ +import pytest + +from lerobot.common.utils.encoding_utils import ( + decode_sign_magnitude, + decode_twos_complement, + encode_sign_magnitude, + encode_twos_complement, +) + + +@pytest.mark.parametrize( + "value, sign_bit_index, expected", + [ + (5, 4, 5), + (0, 4, 0), + (7, 3, 7), + (-1, 4, 17), + (-8, 4, 24), + (-3, 3, 11), + ], +) +def test_encode_sign_magnitude(value, sign_bit_index, expected): + assert encode_sign_magnitude(value, sign_bit_index) == expected + + +@pytest.mark.parametrize( + "encoded, sign_bit_index, expected", + [ + (5, 4, 5), + (0, 4, 0), + (7, 3, 7), + (17, 4, -1), + (24, 4, -8), + (11, 3, -3), + ], +) +def test_decode_sign_magnitude(encoded, sign_bit_index, expected): + assert decode_sign_magnitude(encoded, sign_bit_index) == expected + + +@pytest.mark.parametrize( + "encoded, sign_bit_index", + [ + (16, 4), + (-9, 3), + ], +) +def test_encode_raises_on_overflow(encoded, sign_bit_index): + with pytest.raises(ValueError): + encode_sign_magnitude(encoded, sign_bit_index) + + +def test_encode_decode_sign_magnitude(): + for sign_bit_index in range(2, 6): + max_val = (1 << sign_bit_index) - 1 + for value in range(-max_val, max_val + 1): + encoded = encode_sign_magnitude(value, sign_bit_index) + decoded = decode_sign_magnitude(encoded, sign_bit_index) + assert decoded == value, f"Failed at value={value}, index={sign_bit_index}" + + +@pytest.mark.parametrize( + "value, n_bytes, expected", + [ + (0, 1, 0), + (5, 1, 5), + (-1, 1, 255), + (-128, 1, 128), + (-2, 1, 254), + (127, 1, 127), + (0, 2, 0), + (5, 2, 5), + (-1, 2, 65_535), + (-32_768, 2, 32_768), + (-2, 2, 65_534), + (32_767, 2, 32_767), + (0, 4, 0), + (5, 4, 5), + (-1, 4, 4_294_967_295), + (-2_147_483_648, 4, 2_147_483_648), + (-2, 4, 4_294_967_294), + (2_147_483_647, 4, 2_147_483_647), + ], +) +def test_encode_twos_complement(value, n_bytes, expected): + assert encode_twos_complement(value, n_bytes) == expected + + +@pytest.mark.parametrize( + "value, n_bytes, expected", + [ + (0, 1, 0), + (5, 1, 5), + (255, 1, -1), + (128, 1, -128), + (254, 1, -2), + (127, 1, 127), + (0, 2, 0), + (5, 2, 5), + (65_535, 2, -1), + (32_768, 2, -32_768), + (65_534, 2, -2), + (32_767, 2, 32_767), + (0, 4, 0), + (5, 4, 5), + (4_294_967_295, 4, -1), + (2_147_483_648, 4, -2_147_483_648), + (4_294_967_294, 4, -2), + (2_147_483_647, 4, 2_147_483_647), + ], +) +def test_decode_twos_complement(value, n_bytes, expected): + assert decode_twos_complement(value, n_bytes) == expected + + +@pytest.mark.parametrize( + "value, n_bytes", + [ + (-129, 1), + (128, 1), + (-32_769, 2), + (32_768, 2), + (-2_147_483_649, 4), + (2_147_483_648, 4), + ], +) +def test_encode_twos_complement_out_of_range(value, n_bytes): + with pytest.raises(ValueError): + encode_twos_complement(value, n_bytes) + + +@pytest.mark.parametrize( + "value, n_bytes", + [ + (-128, 1), + (-1, 1), + (0, 1), + (1, 1), + (127, 1), + (-32_768, 2), + (-1, 2), + (0, 2), + (1, 2), + (32_767, 2), + (-2_147_483_648, 4), + (-1, 4), + (0, 4), + (1, 4), + (2_147_483_647, 4), + ], +) +def test_encode_decode_twos_complement(value, n_bytes): + encoded = encode_twos_complement(value, n_bytes) + decoded = decode_twos_complement(encoded, n_bytes) + assert decoded == value, f"Failed at value={value}, n_bytes={n_bytes}" diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py new file mode 100644 index 000000000..054a8593a --- /dev/null +++ b/tests/utils/test_process.py @@ -0,0 +1,112 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import multiprocessing +import os +import signal +import threading +from unittest.mock import patch + +import pytest + +from lerobot.common.utils.process import ProcessSignalHandler + + +# Fixture to reset shutdown_event_counter and original signal handlers before and after each test +@pytest.fixture(autouse=True) +def reset_globals_and_handlers(): + # Store original signal handlers + original_handlers = { + sig: signal.getsignal(sig) + for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGHUP, signal.SIGQUIT] + if hasattr(signal, sig.name) + } + + yield + + # Restore original signal handlers + for sig, handler in original_handlers.items(): + signal.signal(sig, handler) + + +def test_setup_process_handlers_event_with_threads(): + """Test that setup_process_handlers returns the correct event type.""" + handler = ProcessSignalHandler(use_threads=True) + shutdown_event = handler.shutdown_event + assert isinstance(shutdown_event, threading.Event), "Should be a threading.Event" + assert not shutdown_event.is_set(), "Event should initially be unset" + + +def test_setup_process_handlers_event_with_processes(): + """Test that setup_process_handlers returns the correct event type.""" + handler = ProcessSignalHandler(use_threads=False) + shutdown_event = handler.shutdown_event + assert isinstance(shutdown_event, type(multiprocessing.Event())), "Should be a multiprocessing.Event" + assert not shutdown_event.is_set(), "Event should initially be unset" + + +@pytest.mark.parametrize("use_threads", [True, False]) +@pytest.mark.parametrize( + "sig", + [ + signal.SIGINT, + signal.SIGTERM, + # SIGHUP and SIGQUIT are not reliably available on all platforms (e.g. Windows) + pytest.param( + signal.SIGHUP, + marks=pytest.mark.skipif(not hasattr(signal, "SIGHUP"), reason="SIGHUP not available"), + ), + pytest.param( + signal.SIGQUIT, + marks=pytest.mark.skipif(not hasattr(signal, "SIGQUIT"), reason="SIGQUIT not available"), + ), + ], +) +def test_signal_handler_sets_event(use_threads, sig): + """Test that the signal handler sets the event on receiving a signal.""" + handler = ProcessSignalHandler(use_threads=use_threads) + shutdown_event = handler.shutdown_event + + assert handler.counter == 0 + + os.kill(os.getpid(), sig) + + # In some environments, the signal might take a moment to be handled. + shutdown_event.wait(timeout=1.0) + + assert shutdown_event.is_set(), f"Event should be set after receiving signal {sig}" + + # Ensure the internal counter was incremented + assert handler.counter == 1 + + +@pytest.mark.parametrize("use_threads", [True, False]) +@patch("sys.exit") +def test_force_shutdown_on_second_signal(mock_sys_exit, use_threads): + """Test that a second signal triggers a force shutdown.""" + handler = ProcessSignalHandler(use_threads=use_threads) + + os.kill(os.getpid(), signal.SIGINT) + # Give a moment for the first signal to be processed + import time + + time.sleep(0.1) + os.kill(os.getpid(), signal.SIGINT) + + time.sleep(0.1) + + assert handler.counter == 2 + mock_sys_exit.assert_called_once_with(1) diff --git a/tests/utils/test_queue.py b/tests/utils/test_queue.py new file mode 100644 index 000000000..863231e82 --- /dev/null +++ b/tests/utils/test_queue.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import threading +import time +from queue import Queue + +from lerobot.common.utils.queue import get_last_item_from_queue + + +def test_get_last_item_single_item(): + """Test getting the last item when queue has only one item.""" + queue = Queue() + queue.put("single_item") + + result = get_last_item_from_queue(queue) + + assert result == "single_item" + assert queue.empty() + + +def test_get_last_item_multiple_items(): + """Test getting the last item when queue has multiple items.""" + queue = Queue() + items = ["first", "second", "third", "fourth", "last"] + + for item in items: + queue.put(item) + + result = get_last_item_from_queue(queue) + + assert result == "last" + assert queue.empty() + + +def test_get_last_item_different_types(): + """Test with different data types in the queue.""" + queue = Queue() + items = [1, 2.5, "string", {"key": "value"}, [1, 2, 3], ("tuple", "data")] + + for item in items: + queue.put(item) + + result = get_last_item_from_queue(queue) + + assert result == ("tuple", "data") + assert queue.empty() + + +def test_get_last_item_maxsize_queue(): + """Test with a queue that has a maximum size.""" + queue = Queue(maxsize=5) + + # Fill the queue + for i in range(5): + queue.put(i) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue) + + assert result == 4 + assert queue.empty() + + +def test_get_last_item_with_none_values(): + """Test with None values in the queue.""" + queue = Queue() + items = [1, None, 2, None, 3] + + for item in items: + queue.put(item) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue) + + assert result == 3 + assert queue.empty() + + +def test_get_last_item_blocking_timeout(): + """Test get_last_item_from_queue returns None on timeout.""" + queue = Queue() + result = get_last_item_from_queue(queue, block=True, timeout=0.1) + assert result is None + + +def test_get_last_item_non_blocking_empty(): + """Test get_last_item_from_queue with block=False on an empty queue returns None.""" + queue = Queue() + result = get_last_item_from_queue(queue, block=False) + assert result is None + + +def test_get_last_item_non_blocking_success(): + """Test get_last_item_from_queue with block=False on a non-empty queue.""" + queue = Queue() + items = ["first", "second", "last"] + for item in items: + queue.put(item) + + # Give the queue time to fill + time.sleep(0.1) + + result = get_last_item_from_queue(queue, block=False) + assert result == "last" + assert queue.empty() + + +def test_get_last_item_blocking_waits_for_item(): + """Test that get_last_item_from_queue waits for an item if block=True.""" + queue = Queue() + result = [] + + def producer(): + queue.put("item1") + queue.put("item2") + + def consumer(): + # This will block until the producer puts the first item + item = get_last_item_from_queue(queue, block=True, timeout=0.2) + result.append(item) + + producer_thread = threading.Thread(target=producer) + consumer_thread = threading.Thread(target=consumer) + + producer_thread.start() + consumer_thread.start() + + producer_thread.join() + consumer_thread.join() + + assert result == ["item2"] + assert queue.empty() diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py new file mode 100644 index 000000000..f7a055b20 --- /dev/null +++ b/tests/utils/test_replay_buffer.py @@ -0,0 +1,682 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from typing import Callable + +import pytest +import torch + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.utils.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized +from tests.fixtures.constants import DUMMY_REPO_ID + + +def state_dims() -> list[str]: + return ["observation.image", "observation.state"] + + +@pytest.fixture +def replay_buffer() -> ReplayBuffer: + return create_empty_replay_buffer() + + +def clone_state(state: dict) -> dict: + return {k: v.clone() for k, v in state.items()} + + +def create_empty_replay_buffer( + optimize_memory: bool = False, + use_drq: bool = False, + image_augmentation_function: Callable | None = None, +) -> ReplayBuffer: + buffer_capacity = 10 + device = "cpu" + return ReplayBuffer( + buffer_capacity, + device, + state_dims(), + optimize_memory=optimize_memory, + use_drq=use_drq, + image_augmentation_function=image_augmentation_function, + ) + + +def create_random_image() -> torch.Tensor: + return torch.rand(3, 84, 84) + + +def create_dummy_transition() -> dict: + return { + "observation.image": create_random_image(), + "action": torch.randn(4), + "reward": torch.tensor(1.0), + "observation.state": torch.randn( + 10, + ), + "done": torch.tensor(False), + "truncated": torch.tensor(False), + "complementary_info": {}, + } + + +def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]: + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + root = tmp_path / "test" + return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer) + + +def create_dummy_state() -> dict: + return { + "observation.image": create_random_image(), + "observation.state": torch.randn( + 10, + ), + } + + +def get_tensor_memory_consumption(tensor): + return tensor.nelement() * tensor.element_size() + + +def get_tensors_memory_consumption(obj, visited_addresses): + total_size = 0 + + address = id(obj) + if address in visited_addresses: + return 0 + + visited_addresses.add(address) + + if isinstance(obj, torch.Tensor): + return get_tensor_memory_consumption(obj) + elif isinstance(obj, (list, tuple)): + for item in obj: + total_size += get_tensors_memory_consumption(item, visited_addresses) + elif isinstance(obj, dict): + for value in obj.values(): + total_size += get_tensors_memory_consumption(value, visited_addresses) + elif hasattr(obj, "__dict__"): + # It's an object, we need to get the size of the attributes + for _, attr in vars(obj).items(): + total_size += get_tensors_memory_consumption(attr, visited_addresses) + + return total_size + + +def get_object_memory(obj): + # Track visited addresses to avoid infinite loops + # and cases when two properties point to the same object + visited_addresses = set() + + # Get the size of the object in bytes + total_size = sys.getsizeof(obj) + + # Get the size of the tensor attributes + total_size += get_tensors_memory_consumption(obj, visited_addresses) + + return total_size + + +def create_dummy_action() -> torch.Tensor: + return torch.randn(4) + + +def dict_properties() -> list: + return ["state", "next_state"] + + +@pytest.fixture +def dummy_state() -> dict: + return create_dummy_state() + + +@pytest.fixture +def next_dummy_state() -> dict: + return create_dummy_state() + + +@pytest.fixture +def dummy_action() -> torch.Tensor: + return torch.randn(4) + + +def test_empty_buffer_sample_raises_error(replay_buffer): + assert len(replay_buffer) == 0, "Replay buffer should be empty." + assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10." + with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): + replay_buffer.sample(1) + + +def test_zero_capacity_buffer_raises_error(): + with pytest.raises(ValueError, match="Capacity must be greater than 0."): + ReplayBuffer(0, "cpu", ["observation", "next_observation"]) + + +def test_add_transition(replay_buffer, dummy_state, dummy_action): + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding." + assert torch.equal(replay_buffer.actions[0], dummy_action), ( + "Action should be equal to the first transition." + ) + assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition." + assert not replay_buffer.dones[0], "Done should be False for the first transition." + assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition." + + for dim in state_dims(): + assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), ( + "Observation should be equal to the first transition." + ) + assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), ( + "Next observation should be equal to the first transition." + ) + + +def test_add_over_capacity(): + replay_buffer = ReplayBuffer(2, "cpu", ["observation", "next_observation"]) + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + + assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3." + + for dim in state_dims(): + assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), ( + "Observation should be equal to the first transition." + ) + assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), ( + "Next observation should be equal to the first transition." + ) + + assert torch.equal(replay_buffer.actions[0], dummy_action_3), ( + "Action should be equal to the last transition." + ) + assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition." + assert replay_buffer.dones[0], "Done should be True for the first transition." + assert replay_buffer.truncateds[0], "Truncated should be True for the first transition." + + +def test_sample_from_empty_buffer(replay_buffer): + with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"): + replay_buffer.sample(1) + + +def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action): + replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) + got_batch_transition = replay_buffer.sample(1) + + expected_batch_transition = BatchTransition( + state=clone_state(dummy_state), + action=dummy_action.clone(), + reward=1.0, + next_state=clone_state(next_dummy_state), + done=False, + truncated=False, + ) + + for buffer_property in dict_properties(): + for k, v in expected_batch_transition[buffer_property].items(): + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 1, f"{k} should have 1 transition." + assert got_state.device.type == "cpu", f"{k} should be on cpu." + + assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition." + + for key, _value in expected_batch_transition.items(): + if key in dict_properties(): + continue + + got_value = got_batch_transition[key] + + v_tensor = expected_batch_transition[key] + if not isinstance(v_tensor, torch.Tensor): + v_tensor = torch.tensor(v_tensor) + + assert got_value.shape[0] == 1, f"{key} should have 1 transition." + assert got_value.device.type == "cpu", f"{key} should be on cpu." + assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition." + + +def test_sample_with_batch_bigger_than_buffer_size( + replay_buffer, dummy_state, next_dummy_state, dummy_action +): + replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False) + got_batch_transition = replay_buffer.sample(10) + + expected_batch_transition = BatchTransition( + state=dummy_state, + action=dummy_action, + reward=1.0, + next_state=next_dummy_state, + done=False, + truncated=False, + ) + + for buffer_property in dict_properties(): + for k in expected_batch_transition[buffer_property]: + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 1, f"{k} should have 1 transition." + + for key in expected_batch_transition: + if key in dict_properties(): + continue + + got_value = got_batch_transition[key] + assert got_value.shape[0] == 1, f"{key} should have 1 transition." + + +def test_sample_batch(replay_buffer): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True) + + dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4] + dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4] + + got_batch_transition = replay_buffer.sample(3) + + for buffer_property in dict_properties(): + for k in got_batch_transition[buffer_property]: + got_state = got_batch_transition[buffer_property][k] + + assert got_state.shape[0] == 3, f"{k} should have 3 transition." + + for got_state_item in got_state: + assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), ( + f"{k} should be equal to one of the dummy states." + ) + + for got_action_item in got_batch_transition["action"]: + assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), ( + "Actions should be equal to the dummy actions." + ) + + for k in got_batch_transition: + if k in dict_properties() or k == "complementary_info": + continue + + got_value = got_batch_transition[k] + assert got_value.shape[0] == 3, f"{k} should have 3 transition." + + +def test_to_lerobot_dataset_with_empty_buffer(replay_buffer): + with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."): + replay_buffer.to_lerobot_dataset("dummy_repo") + + +def test_to_lerobot_dataset(tmp_path): + ds, buffer = create_dataset_from_replay_buffer(tmp_path) + + assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer" + assert ds.fps == 1, "FPS should be 1" + assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id" + + for dim in state_dims(): + assert dim in ds.features + assert ds.features[dim]["shape"] == buffer.states[dim][0].shape + + assert ds.num_episodes == 2 + assert ds.num_frames == 4 + + for j, value in enumerate(ds): + print(torch.equal(value["observation.image"], buffer.next_states["observation.image"][j])) + + for i in range(len(ds)): + for feature, value in ds[i].items(): + if feature == "action": + assert torch.equal(value, buffer.actions[i]) + elif feature == "next.reward": + assert torch.equal(value, buffer.rewards[i]) + elif feature == "next.done": + assert torch.equal(value, buffer.dones[i]) + elif feature == "observation.image": + # Tenssor -> numpy is not precise, so we have some diff there + # TODO: Check and fix it + torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) + elif feature == "observation.state": + assert torch.equal(value, buffer.states["observation.state"][i]) + + +def test_from_lerobot_dataset(tmp_path): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + root = tmp_path / "test" + ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root) + + reconverted_buffer = ReplayBuffer.from_lerobot_dataset( + ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False + ) + + # Check only the part of the buffer that's actually filled with data + assert torch.equal( + reconverted_buffer.actions[: len(replay_buffer)], + replay_buffer.actions[: len(replay_buffer)], + ), "Actions from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)] + ), "Rewards from converted buffer should be equal to the original replay buffer." + assert torch.equal( + reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)] + ), "Dones from converted buffer should be equal to the original replay buffer." + + # Lerobot DS haven't supported truncateds yet + expected_truncateds = torch.zeros(len(replay_buffer)).bool() + assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), ( + "Truncateds from converted buffer should be equal False" + ) + + assert torch.equal( + replay_buffer.states["observation.state"][: len(replay_buffer)], + reconverted_buffer.states["observation.state"][: len(replay_buffer)], + ), "State should be the same after converting to dataset and return back" + + for i in range(4): + torch.testing.assert_close( + replay_buffer.states["observation.image"][i], + reconverted_buffer.states["observation.image"][i], + rtol=0.4, + atol=0.004, + ) + + # The 2, 3 frames have done flag, so their values will be equal to the current state + for i in range(2): + # In the current implementation we take the next state from the `states` and ignore `next_states` + next_index = (i + 1) % 4 + + torch.testing.assert_close( + replay_buffer.states["observation.image"][next_index], + reconverted_buffer.next_states["observation.image"][i], + rtol=0.4, + atol=0.004, + ) + + for i in range(2, 4): + assert torch.equal( + replay_buffer.states["observation.state"][i], + reconverted_buffer.next_states["observation.state"][i], + ) + + +def test_buffer_sample_alignment(): + # Initialize buffer + buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu") + + # Fill buffer with patterned data + for i in range(100): + signature = float(i) / 100.0 + state = {"state_value": torch.tensor([[signature]]).float()} + action = torch.tensor([[2.0 * signature]]).float() + reward = 3.0 * signature + + is_end = (i + 1) % 10 == 0 + if is_end: + next_state = {"state_value": torch.tensor([[signature]]).float()} + done = True + else: + next_signature = float(i + 1) / 100.0 + next_state = {"state_value": torch.tensor([[next_signature]]).float()} + done = False + + buffer.add(state, action, reward, next_state, done, False) + + # Sample and verify + batch = buffer.sample(50) + + for i in range(50): + state_sig = batch["state"]["state_value"][i].item() + action_val = batch["action"][i].item() + reward_val = batch["reward"][i].item() + next_state_sig = batch["next_state"]["state_value"][i].item() + is_done = batch["done"][i].item() > 0.5 + + # Verify relationships + assert abs(action_val - 2.0 * state_sig) < 1e-4, ( + f"Action {action_val} should be 2x state signature {state_sig}" + ) + + assert abs(reward_val - 3.0 * state_sig) < 1e-4, ( + f"Reward {reward_val} should be 3x state signature {state_sig}" + ) + + if is_done: + assert abs(next_state_sig - state_sig) < 1e-4, ( + f"For done states, next_state {next_state_sig} should equal state {state_sig}" + ) + else: + # Either it's the next sequential state (+0.01) or same state (for episode boundaries) + valid_next = ( + abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4 + ) + assert valid_next, ( + f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}" + ) + + +def test_memory_optimization(): + dummy_state_1 = create_dummy_state() + dummy_action_1 = create_dummy_action() + + dummy_state_2 = create_dummy_state() + dummy_action_2 = create_dummy_action() + + dummy_state_3 = create_dummy_state() + dummy_action_3 = create_dummy_action() + + dummy_state_4 = create_dummy_state() + dummy_action_4 = create_dummy_action() + + replay_buffer = create_empty_replay_buffer() + replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) + replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) + replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) + replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True) + + optimized_replay_buffer = create_empty_replay_buffer(True) + optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False) + optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False) + optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False) + optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True) + + assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), ( + "Optimized replay buffer should be smaller than the original replay buffer" + ) + + +def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action): + def dummy_image_augmentation_function(x): + return torch.ones_like(x) * 10 + + replay_buffer = create_empty_replay_buffer( + use_drq=True, image_augmentation_function=dummy_image_augmentation_function + ) + + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + + sampled_transitions = replay_buffer.sample(1) + assert torch.all(sampled_transitions["state"]["observation.image"] == 10), ( + "Image augmentations should be applied" + ) + assert torch.all(sampled_transitions["next_state"]["observation.image"] == 10), ( + "Image augmentations should be applied" + ) + + +def test_check_image_augmentations_with_drq_and_default_image_augmentation_function( + dummy_state, dummy_action +): + replay_buffer = create_empty_replay_buffer(use_drq=True) + + replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False) + + # Let's check that it doesn't fail and shapes are correct + sampled_transitions = replay_buffer.sample(1) + assert sampled_transitions["state"]["observation.image"].shape == (1, 3, 84, 84) + assert sampled_transitions["next_state"]["observation.image"].shape == (1, 3, 84, 84) + + +def test_random_crop_vectorized_basic(): + # Create a batch of 2 images with known patterns + batch_size, channels, height, width = 2, 3, 10, 8 + images = torch.zeros((batch_size, channels, height, width)) + + # Fill with unique values for testing + for b in range(batch_size): + images[b] = b + 1 + + crop_size = (6, 4) # Smaller than original + cropped = random_crop_vectorized(images, crop_size) + + # Check output shape + assert cropped.shape == (batch_size, channels, *crop_size) + + # Check that values are preserved (should be either 1s or 2s for respective batches) + assert torch.all(cropped[0] == 1) + assert torch.all(cropped[1] == 2) + + +def test_random_crop_vectorized_invalid_size(): + images = torch.zeros((2, 3, 10, 8)) + + # Test crop size larger than image + with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): + random_crop_vectorized(images, (12, 8)) + + with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"): + random_crop_vectorized(images, (10, 10)) + + +def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer: + """Create a small buffer with deterministic 3×128×128 images and 11-D state.""" + buffer = ReplayBuffer( + capacity=capacity, + device="cpu", + state_keys=["observation.image", "observation.state"], + storage_device="cpu", + ) + + for i in range(capacity): + img = torch.ones(3, 128, 128) * i + state_vec = torch.arange(11).float() + i + state = { + "observation.image": img, + "observation.state": state_vec, + } + buffer.add( + state=state, + action=torch.tensor([0.0]), + reward=0.0, + next_state=state, + done=False, + truncated=False, + ) + return buffer + + +def test_async_iterator_shapes_basic(): + buffer = _populate_buffer_for_async_test() + batch_size = 2 + iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1) + batch = next(iterator) + + images = batch["state"]["observation.image"] + states = batch["state"]["observation.state"] + + assert images.shape == (batch_size, 3, 128, 128) + assert states.shape == (batch_size, 11) + + next_images = batch["next_state"]["observation.image"] + next_states = batch["next_state"]["observation.state"] + + assert next_images.shape == (batch_size, 3, 128, 128) + assert next_states.shape == (batch_size, 11) + + +def test_async_iterator_multiple_iterations(): + buffer = _populate_buffer_for_async_test() + batch_size = 2 + iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=2) + + for _ in range(5): + batch = next(iterator) + images = batch["state"]["observation.image"] + states = batch["state"]["observation.state"] + assert images.shape == (batch_size, 3, 128, 128) + assert states.shape == (batch_size, 11) + + next_images = batch["next_state"]["observation.image"] + next_states = batch["next_state"]["observation.state"] + assert next_images.shape == (batch_size, 3, 128, 128) + assert next_states.shape == (batch_size, 11) + + # Ensure iterator can be disposed without blocking + del iterator