diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 22f1ee3d7..d37b1a92f 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -30,7 +30,7 @@ pytest -sx tests/test_stuff.py::test_something ``` ```bash -python -m lerobot.scripts.train --some.option=true +lerobot-train --some.option=true ``` ## SECTION TO REMOVE BEFORE SUBMITTING YOUR PR diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml index b42b92f6b..03f26a792 100644 --- a/.github/workflows/nightly.yml +++ b/.github/workflows/nightly.yml @@ -29,8 +29,8 @@ on: env: UV_VERSION: "0.8.0" PYTHON_VERSION: "3.10" - DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-gpu:latest - DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-cpu:latest + DOCKER_IMAGE_NAME_CPU: huggingface/lerobot-cpu:latest + DOCKER_IMAGE_NAME_GPU: huggingface/lerobot-gpu:latest # Ensures that only the latest commit is built, canceling older runs. concurrency: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 32c1c605a..67aa5186b 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -19,6 +19,11 @@ on: tags: - 'v*.*.*' # Trigger on tags like v0.1.0, v1.0.0 +# Sets up the environment variables +env: + UV_VERSION: "0.8.0" + PYTHON_VERSION: "3.10" + jobs: # This job builds the Python package and publishes it to PyPI build-and-publish: @@ -50,6 +55,7 @@ jobs: VERSION_NUMBER=${VERSION#v} echo "tag_version=$VERSION_NUMBER" >> $GITHUB_OUTPUT - name: Check if version matches pyproject.toml + if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') # zizmor: ignore[template-injection] run: | TAG_VERSION=${{ steps.extract_info.outputs.tag_version }} @@ -86,13 +92,29 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # zizmor: ignore[template-injection] - run: gh release create ${{ github.ref_name }} --release-name "Release ${{ github.ref_name }}" --generate-notes ./dist/* + run: | + gh release create ${{ 
github.ref_name }} \ + --title "Release ${{ github.ref_name }}" \ + --generate-notes \ + --draft=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \ + --prerelease=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \ + ./dist/* - - name: Publish to PyPI - if: startsWith(github.ref, 'refs/tags/v') + - name: Publish to TestPyPI for pre-releases + # True for tags like 'v0.2.0-rc1' + if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-') uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing] with: - password: ${{ secrets.PYPI_API_TOKEN }} + repository-url: https://test.pypi.org/legacy/ + verbose: true + print-hash: true + + - name: Publish to PyPI + if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-') + uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing] + with: + verbose: true + print-hash: true # This job runs end-to-end tests on the release test-release: @@ -119,15 +141,31 @@ jobs: enable-cache: true version: ${{ env.UV_VERSION }} python-version: ${{ env.PYTHON_VERSION }} + - name: Create uv virtual environment + run: uv venv - name: Install lerobot release - run: uv run pip install lerobot==${{ needs.build-and-publish.outputs.version }} # zizmor: ignore[template-injection] - + # zizmor: ignore[template-injection] + run: | + VERSION="${{ needs.build-and-publish.outputs.version }}" + if [[ "$VERSION" == *-* ]]; then + BASE_VERSION="${VERSION%%-*}" + echo "Installing pre-release version $BASE_VERSION from TestPyPI..." + uv pip install \ + --index-url https://test.pypi.org/simple/ \ + --extra-index-url https://pypi.org/simple \ + --index-strategy unsafe-best-match \ + "lerobot[all]==$BASE_VERSION" + else + echo "Installing release version $VERSION from PyPI..." 
+ uv pip install "lerobot[all]==$VERSION" + fi - name: Check lerobot version - run: uv run lerobot --version + run: uv run python -c "import lerobot; print(lerobot.__version__)" - name: Run end-to-end tests run: uv run make test-end-to-end -# TODO(Steven): Publish draft/pre-release and to test pypi +# TODO(Steven): Publish draft/pre-release and to test pypi weekly +# TODO(Steven): Separate build and publish job # TODO(Steven): Tag documentation with the same version as the package diff --git a/Makefile b/Makefile index 5bfbe76a2..fbe8a5bae 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ test-end-to-end: ${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval test-act-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=act \ --policy.dim_model=64 \ --policy.n_action_steps=20 \ @@ -68,12 +68,12 @@ test-act-ete-train: --output_dir=tests/outputs/act/ test-act-ete-train-resume: - python -m lerobot.scripts.train \ + lerobot-train \ --config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \ --resume=true test-act-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=aloha \ @@ -82,7 +82,7 @@ test-act-ete-eval: --eval.batch_size=1 test-diffusion-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=diffusion \ --policy.down_dims='[64,128,256]' \ --policy.diffusion_step_embed_dim=32 \ @@ -106,7 +106,7 @@ test-diffusion-ete-train: --output_dir=tests/outputs/diffusion/ test-diffusion-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=pusht \ @@ -115,7 +115,7 @@ test-diffusion-ete-eval: --eval.batch_size=1 test-tdmpc-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=tdmpc \ --policy.device=$(DEVICE) \ --policy.push_to_hub=false \ @@ 
-137,7 +137,7 @@ test-tdmpc-ete-train: --output_dir=tests/outputs/tdmpc/ test-tdmpc-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=xarm \ @@ -148,7 +148,7 @@ test-tdmpc-ete-eval: test-smolvla-ete-train: - python -m lerobot.scripts.train \ + lerobot-train \ --policy.type=smolvla \ --policy.n_action_steps=20 \ --policy.chunk_size=20 \ @@ -171,7 +171,7 @@ test-smolvla-ete-train: --output_dir=tests/outputs/smolvla/ test-smolvla-ete-eval: - python -m lerobot.scripts.eval \ + lerobot-eval \ --policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=aloha \ diff --git a/README.md b/README.md index 1d7cbcad4..b5e666aa8 100644 --- a/README.md +++ b/README.md @@ -1,25 +1,21 @@

- - - - LeRobot, Hugging Face Robotics Library - + LeRobot, Hugging Face Robotics Library

-[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly-tests.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nightly-tests.yml?query=branch%3Amain) -[![Coverage](https://codecov.io/gh/huggingface/lerobot/branch/main/graph/badge.svg?token=TODO)](https://codecov.io/gh/huggingface/lerobot) +[![Tests](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml/badge.svg?branch=main)](https://github.com/huggingface/lerobot/actions/workflows/nightly.yml?query=branch%3Amain) [![Python versions](https://img.shields.io/pypi/pyversions/lerobot)](https://www.python.org/downloads/) [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://github.com/huggingface/lerobot/blob/main/LICENSE) [![Status](https://img.shields.io/pypi/status/lerobot)](https://pypi.org/project/lerobot/) [![Version](https://img.shields.io/pypi/v/lerobot)](https://pypi.org/project/lerobot/) -[![Examples](https://img.shields.io/badge/Examples-green.svg)](https://github.com/huggingface/lerobot/tree/main/examples) -[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1%20adopted-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) +[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v2.1-ff69b4.svg)](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md) [![Discord](https://dcbadge.vercel.app/api/server/C5P34WJ68S?style=flat)](https://discord.gg/s3KuuzsPFb) + +

@@ -29,10 +25,10 @@
HopeJR robot

Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!

@@ -51,20 +47,12 @@

-
- SO-101 follower arm - SO-101 leader arm -
+ + + + + +
SO-101 follower armSO-101 leader arm

Meet the updated SO100, the SO-101 – Just €114 per arm!

Train it in minutes with a few simple moves on your laptop.

@@ -76,7 +64,7 @@

Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!

Check out the LeKiwi tutorial and bring your robot to life on wheels.

- LeKiwi mobile robot + LeKiwi mobile robot

@@ -99,9 +87,9 @@ - - - + + + @@ -110,23 +98,11 @@
ACT policy on ALOHA envTDMPC policy on SimXArm envDiffusion policy on PushT envACT policy on ALOHA envTDMPC policy on SimXArm envDiffusion policy on PushT env
ACT policy on ALOHA env
-### Acknowledgment - -- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla). -- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). -- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io). -- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM). -- Thanks to Antonio Loquercio and Ashish Kumar for their early support. -- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official). - ## Installation -Download our source code: +LeRobot works with Python 3.10+ and PyTorch 2.2+. -```bash -git clone https://github.com/huggingface/lerobot.git -cd lerobot -``` +### Environment Setup Create a virtual environment with Python 3.10 and activate it, e.g. 
with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html): @@ -151,7 +127,18 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -Install 🤗 LeRobot: +### Install LeRobot 🤗 + +#### From Source + +First, clone the repository and navigate into the directory: + +```bash +git clone https://github.com/huggingface/lerobot.git +cd lerobot +``` + +Then, install the library in editable mode. This is useful if you plan to contribute to the code. ```bash pip install -e . @@ -172,6 +159,34 @@ For instance, to install 🤗 LeRobot with aloha and pusht, use: pip install -e ".[aloha, pusht]" ``` +### Installation from PyPI + +**Core Library:** +Install the base package with: + +```bash +pip install lerobot +``` + +_This installs only the default dependencies._ + +**Extra Features:** +To install additional functionality, use one of the following: + +```bash +pip install 'lerobot[all]' # All available features +pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht) +pip install 'lerobot[feetech]' # Feetech motor support +``` + +_Replace `[...]` with your desired features._ + +**Available Tags:** +For a full list of optional dependencies, see: +https://pypi.org/project/lerobot/ + +### Weights & Biases + To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with ```bash @@ -182,7 +197,7 @@ wandb login ### Visualize datasets -Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. 
+Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub. You can also locally visualize episodes from a dataset on the hub by executing our script from the command line: @@ -212,7 +227,7 @@ Our script can also visualize datasets stored on a distant server. See `python - A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model. -A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`. +A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. 
For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`. Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor. @@ -256,12 +271,12 @@ Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work ### Evaluate a pretrained policy -Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment. +Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment. We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. 
Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht): ```bash -python -m lerobot.scripts.eval \ +lerobot-eval \ --policy.path=lerobot/diffusion_pusht \ --env.type=pusht \ --eval.batch_size=10 \ @@ -273,22 +288,22 @@ python -m lerobot.scripts.eval \ Note: After training your own policy, you can re-evaluate the checkpoints with: ```bash -python -m lerobot.scripts.eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model +lerobot-eval --policy.path={OUTPUT_DIR}/checkpoints/last/pretrained_model ``` -See `python -m lerobot.scripts.eval --help` for more instructions. +See `lerobot-eval --help` for more instructions. ### Train your own policy -Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line. +Check out [example 3](https://github.com/huggingface/lerobot/blob/main/examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md) that shows how to use our training script from command line. To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`. -A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs. +A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. 
Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs. -![](media/wandb.png) +\WandB logs example -Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions. +Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `lerobot-eval --help` for more instructions. #### Reproduce state-of-the-art (SOTA) @@ -296,7 +311,7 @@ We provide some pretrained policies on our [hub page](https://huggingface.co/ler You can reproduce their training by loading the config from their run. Simply running: ```bash -python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht +lerobot-train --config_path=lerobot/diffusion_pusht ``` reproduces SOTA results for Diffusion Policy on the PushT task. @@ -305,26 +320,6 @@ reproduces SOTA results for Diffusion Policy on the PushT task. If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md). - - ### Add a pretrained policy Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)). 
@@ -341,34 +336,16 @@ To upload these to the hub, run the following: huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model ``` -See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy. +See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy. -### Improve your code with profiling +### Acknowledgment -An example of a code snippet to profile the evaluation of a policy: - - -```python -from torch.profiler import profile, record_function, ProfilerActivity - -def trace_handler(prof): - prof.export_chrome_trace(f"tmp/trace_schedule_{prof.step_num}.json") - -with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - schedule=torch.profiler.schedule( - wait=2, - warmup=2, - active=3, - ), - on_trace_ready=trace_handler -) as prof: - with record_function("eval_policy"): - for i in range(num_episodes): - prof.step() - # insert code to profile, potentially whole body of eval_policy function -``` - +- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla). +- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). +- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io). +- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM). 
+- Thanks to Antonio Loquercio and Ashish Kumar for their early support. +- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official). ## Citation @@ -376,83 +353,13 @@ If you want, you can cite this work with: ```bibtex @misc{cadene2024lerobot, - author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, + author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch}, howpublished = "\url{https://github.com/huggingface/lerobot}", year = {2024} } ``` -Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below: - -- [SmolVLA](https://arxiv.org/abs/2506.01844) - -```bibtex -@article{shukor2025smolvla, - title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics}, - author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi}, - journal={arXiv 
preprint arXiv:2506.01844}, - year={2025} -} -``` - -- [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) - -```bibtex -@article{chi2024diffusionpolicy, - author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, - title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, - journal = {The International Journal of Robotics Research}, - year = {2024}, -} -``` - -- [ACT or ALOHA](https://tonyzhaozh.github.io/aloha) - -```bibtex -@article{zhao2023learning, - title={Learning fine-grained bimanual manipulation with low-cost hardware}, - author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea}, - journal={arXiv preprint arXiv:2304.13705}, - year={2023} -} -``` - -- [TDMPC](https://www.nicklashansen.com/td-mpc/) - -```bibtex -@inproceedings{Hansen2022tdmpc, - title={Temporal Difference Learning for Model Predictive Control}, - author={Nicklas Hansen and Xiaolong Wang and Hao Su}, - booktitle={ICML}, - year={2022} -} -``` - -- [VQ-BeT](https://sjlee.cc/vq-bet/) - -```bibtex -@article{lee2024behavior, - title={Behavior generation with latent actions}, - author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel}, - journal={arXiv preprint arXiv:2403.03181}, - year={2024} -} -``` - -- [HIL-SERL](https://hil-serl.github.io/) - -```bibtex -@Article{luo2024hilserl, -title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, -author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine}, -year={2024}, -eprint={2410.21845}, -archivePrefix={arXiv}, -primaryClass={cs.RO} -} -``` - ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) diff --git a/docker/Dockerfile.user b/docker/Dockerfile.user index 4cfbb437a..bcd067637 100644 
--- a/docker/Dockerfile.user +++ b/docker/Dockerfile.user @@ -29,7 +29,7 @@ ENV DEBIAN_FRONTEND=noninteractive \ # Install system dependencies and uv (as root) RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential git curl libglib2.0-0 libegl1-mesa ffmpeg \ + build-essential git curl libglib2.0-0 libegl1-mesa-dev ffmpeg \ libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \ && curl -LsSf https://astral.sh/uv/install.sh | sh \ && mv /root/.local/bin/uv /usr/local/bin/uv \ diff --git a/docs-requirements.txt b/docs-requirements.txt new file mode 100644 index 000000000..e286ad2bb --- /dev/null +++ b/docs-requirements.txt @@ -0,0 +1,3 @@ +# docs-requirements.txt +hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main +watchdog>=6.0.0 diff --git a/docs/README.md b/docs/README.md index 967de7b84..476eb8dce 100644 --- a/docs/README.md +++ b/docs/README.md @@ -20,7 +20,7 @@ To generate the documentation, you first have to build it. Several packages are you can install them with the following command, at the root of the code repository: ```bash -pip install -e ".[docs]" +pip install -e . -r docs-requirements.txt ``` You will also need `nodejs`. Please refer to their [installation page](https://nodejs.org/en/download) diff --git a/docs/source/cameras.mdx b/docs/source/cameras.mdx index 604863d74..5c35be0ba 100644 --- a/docs/source/cameras.mdx +++ b/docs/source/cameras.mdx @@ -9,7 +9,7 @@ To instantiate a camera, you need a camera identifier. 
This identifier might cha To find the camera indices of the cameras plugged into your system, run the following script: ```bash -python -m lerobot.find_cameras opencv # or realsense for Intel Realsense cameras +lerobot-find-cameras opencv # or realsense for Intel Realsense cameras ``` The output will look something like this if you have two cameras connected: diff --git a/docs/source/hilserl.mdx b/docs/source/hilserl.mdx index f66d8cab7..f8a5c69b2 100644 --- a/docs/source/hilserl.mdx +++ b/docs/source/hilserl.mdx @@ -28,7 +28,7 @@ This guide provides step-by-step instructions for training a robot policy using - A gamepad (recommended) or keyboard to control the robot - A Nvidia GPU - A real robot with a follower and leader arm (optional if you use the keyboard or the gamepad) -- A URDF file for the robot for the kinematics package (check `lerobot/common/model/kinematics.py`) +- A URDF file for the robot for the kinematics package (check `lerobot/model/kinematics.py`) ## What kind of tasks can I train? @@ -412,7 +412,7 @@ Example configuration for training the [reward classifier](https://huggingface.c To train the classifier, use the `train.py` script with your configuration: ```bash -python -m lerobot.scripts.train --config_path path/to/reward_classifier_train_config.json +lerobot-train --config_path path/to/reward_classifier_train_config.json ``` **Deploying and Testing the Model** @@ -458,7 +458,7 @@ The reward classifier will automatically provide rewards based on the visual inp 3. **Train the classifier**: ```bash - python -m lerobot.scripts.train --config_path src/lerobot/configs/reward_classifier_train_config.json + lerobot-train --config_path src/lerobot/configs/reward_classifier_train_config.json ``` 4. **Test the classifier**: @@ -477,7 +477,7 @@ Create a training configuration file (example available [here](https://huggingfa 1. Configure the policy settings (`type="sac"`, `device`, etc.) 2. Set `dataset` to your cropped dataset 3. 
Configure environment settings with crop parameters -4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/19bb621a7d0a31c20cd3cc08b1dbab68d3031454/lerobot/common/policies/sac/configuration_sac.py#L79). +4. Check the other parameters related to SAC in [configuration_sac.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/sac/configuration_sac.py#L79). 5. Verify that the `policy` config is correct with the right `input_features` and `output_features` for your task. **Starting the Learner** diff --git a/docs/source/hope_jr.mdx b/docs/source/hope_jr.mdx index 72aa8f923..856febb95 100644 --- a/docs/source/hope_jr.mdx +++ b/docs/source/hope_jr.mdx @@ -19,7 +19,7 @@ pip install -e ".[hopejr]" Before starting calibration and operation, you need to identify the USB ports for each HopeJR component. Run this script to find the USB ports for the arm, hand, glove, and exoskeleton: ```bash -python -m lerobot.find_port +lerobot-find-port ``` This will display the available USB ports and their associated devices. Make note of the port paths (e.g., `/dev/tty.usbmodem58760433331`, `/dev/tty.usbmodem11301`) as you'll need to specify them in the `--robot.port` and `--teleop.port` parameters when recording data, replaying episodes, or running teleoperation scripts. @@ -31,7 +31,7 @@ Before performing teleoperation, HopeJR's limbs need to be calibrated. 
Calibrati ### 1.1 Calibrate Robot Hand ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=blue \ @@ -81,7 +81,7 @@ Once you have set the appropriate boundaries for all joints, click "Save" to sav ### 1.2 Calibrate Teleoperator Glove ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=homunculus_glove \ --teleop.port=/dev/tty.usbmodem11201 \ --teleop.id=red \ @@ -120,7 +120,7 @@ Once calibration is complete, the system will save the calibration to `/Users/yo ### 1.3 Calibrate Robot Arm ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=hope_jr_arm \ --robot.port=/dev/tty.usbserial-1110 \ --robot.id=white @@ -146,7 +146,7 @@ Use the calibration interface to set the range boundaries for each joint. Move e ### 1.4 Calibrate Teleoperator Exoskeleton ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=homunculus_arm \ --teleop.port=/dev/tty.usbmodem11201 \ --teleop.id=black @@ -178,7 +178,7 @@ Due to global variable conflicts in the Feetech middleware, teleoperation for ar ### Hand ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=blue \ @@ -194,7 +194,7 @@ python -m lerobot.teleoperate \ ### Arm ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=hope_jr_arm \ --robot.port=/dev/tty.usbserial-1110 \ --robot.id=white \ @@ -214,7 +214,7 @@ Record, Replay and Train with Hope-JR is still experimental. This step records the dataset, which can be seen as an example [here](https://huggingface.co/datasets/nepyope/hand_record_test_with_video_data/settings). 
```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ @@ -236,7 +236,7 @@ python -m lerobot.record \ ### Replay ```bash -python -m lerobot.replay \ +lerobot-replay \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ @@ -248,7 +248,7 @@ python -m lerobot.replay \ ### Train ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=nepyope/hand_record_test_with_video_data \ --policy.type=act \ --output_dir=outputs/train/hopejr_hand \ @@ -263,7 +263,7 @@ python -m lerobot.scripts.train \ This training run can be viewed as an example [here](https://wandb.ai/tino/lerobot/runs/rp0k8zvw?nw=nwusertino). ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=hope_jr_hand \ --robot.port=/dev/tty.usbmodem58760432281 \ --robot.id=right \ diff --git a/docs/source/il_robots.mdx b/docs/source/il_robots.mdx index ccca6d508..905046bef 100644 --- a/docs/source/il_robots.mdx +++ b/docs/source/il_robots.mdx @@ -45,7 +45,7 @@ Note that the `id` associated with a robot is used to store the calibration file ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=my_awesome_follower_arm \ @@ -101,7 +101,7 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam ```bash -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=koch_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=my_awesome_follower_arm \ @@ -174,7 +174,7 @@ Now you can record a dataset. 
To record 5 episodes and upload your dataset to th ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem585A0076841 \ --robot.id=my_awesome_follower_arm \ @@ -294,7 +294,7 @@ dataset.push_to_hub() #### Dataset upload -Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running: +Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. `https://huggingface.co/datasets/${HF_USER}/so101_test`) that you can obtain by running: ```bash echo https://huggingface.co/datasets/${HF_USER}/so101_test @@ -323,7 +323,7 @@ The `record` function provides a suite of tools for capturing and managing data ##### 2. Checkpointing and Resuming - Checkpoints are automatically created during recording. -- If an issue occurs, you can resume by re-running the same command with `--resume=true`. +- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset ! - To start recording from scratch, **manually delete** the dataset directory. ##### 3. Recording Parameters @@ -376,7 +376,7 @@ You can replay the first episode on your robot with either the command below or ```bash -python -m lerobot.replay \ +lerobot-replay \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=my_awesome_follower_arm \ @@ -428,10 +428,10 @@ Your robot should replicate movements similar to those you recorded. 
For example ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/so101_test \ --policy.type=act \ --output_dir=outputs/train/act_so101_test \ @@ -444,7 +444,7 @@ python -m lerobot.scripts.train \ Let's explain the command: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. 3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. 4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. @@ -453,7 +453,7 @@ Training should take several hours. 
You will find checkpoints in `outputs/train/ To resume training from a checkpoint, below is an example command to resume from `last` checkpoint of the `act_so101_test` policy: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \ --resume=true ``` @@ -462,9 +462,9 @@ If you do not want to push your model to the hub after training use `--policy.pu Additionally you can provide extra `tags` or specify a `license` for your model or make the model repo `private` by adding this: `--policy.private=true --policy.tags=\[ppo,rl\] --policy.license=mit` -#### Train using Collab +#### Train using Google Colab -If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act). +If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act). #### Upload policy checkpoints @@ -490,7 +490,7 @@ You can use the `record` script from [`lerobot/record.py`](https://github.com/hu ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so100_follower \ --robot.port=/dev/ttyACM1 \ --robot.cameras="{ up: {type: opencv, index_or_path: /dev/video10, width: 640, height: 480, fps: 30}, side: {type: intelrealsense, serial_number_or_name: 233522074606, width: 640, height: 480, fps: 30}}" \ diff --git a/docs/source/il_sim.mdx b/docs/source/il_sim.mdx index 193b09b1b..3dd80dc4b 100644 --- a/docs/source/il_sim.mdx +++ b/docs/source/il_sim.mdx @@ -96,10 +96,10 @@ If you uploaded your dataset to the hub you can [visualize your dataset online]( ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. 
Here is an example command: +To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/il_gym \ --policy.type=act \ --output_dir=outputs/train/il_sim_test \ @@ -111,7 +111,7 @@ python -m lerobot.scripts.train \ Let's explain the command: 1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`. -2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. 3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. 4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx index 13c3600b4..93354c2ee 100644 --- a/docs/source/installation.mdx +++ b/docs/source/installation.mdx @@ -1,15 +1,6 @@ # Installation -## Install LeRobot - -Currently only available from source. 
- -Download our source code: - -```bash -git clone https://github.com/huggingface/lerobot.git -cd lerobot -``` +## Environment Setup Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install) @@ -40,12 +31,49 @@ conda install ffmpeg -c conda-forge > > - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`. -Install 🤗 LeRobot: +## Install LeRobot 🤗 + +### From Source + +First, clone the repository and navigate into the directory: + +```bash +git clone https://github.com/huggingface/lerobot.git +cd lerobot +``` + +Then, install the library in editable mode. This is useful if you plan to contribute to the code. ```bash pip install -e . ``` +### Installation from PyPI + +**Core Library:** +Install the base package with: + +```bash +pip install lerobot +``` + +_This installs only the default dependencies._ + +**Extra Features:** +To install additional functionality, use one of the following: + +```bash +pip install 'lerobot[all]' # All available features +pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht) +pip install 'lerobot[feetech]' # Feetech motor support +``` + +_Replace `[...]` with your desired features._ + +**Available Tags:** +For a full list of optional dependencies, see: +https://pypi.org/project/lerobot/ + ### Troubleshooting If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`. 
diff --git a/docs/source/koch.mdx b/docs/source/koch.mdx index d0b991e74..3e94899a8 100644 --- a/docs/source/koch.mdx +++ b/docs/source/koch.mdx @@ -31,7 +31,7 @@ pip install -e ".[dynamixel]" To find the port for each bus servo adapter, run this script: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -98,7 +98,7 @@ For a visual reference on how to set the motor ids please refer to [this video]( ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=koch_follower \ --robot.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` @@ -174,7 +174,7 @@ Do the same steps for the leader arm but modify the command or script accordingl ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=koch_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 \ # <- paste here the port found at previous step ``` @@ -211,7 +211,7 @@ Run the following command or API example to calibrate the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=koch_follower \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name @@ -249,7 +249,7 @@ Do the same steps to calibrate the leader arm, run the following command or API ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=koch_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name diff --git a/docs/source/lekiwi.mdx b/docs/source/lekiwi.mdx index bb70fd26b..14c06e444 100644 --- a/docs/source/lekiwi.mdx +++ b/docs/source/lekiwi.mdx @@ -60,7 +60,7 @@ First, we will assemble the two SO100/SO101 arms. 
One to attach to the mobile ba To find the port for each bus servo adapter, run this script: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -116,7 +116,7 @@ The instructions for configuring the motors can be found in the SO101 [docs](./s You can run this command to setup motors for LeKiwi. It will first setup the motors for arm (id 6..1) and then setup motors for wheels (9,8,7) ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=lekiwi \ --robot.port=/dev/tty.usbmodem58760431551 # <- paste here the port found at previous step ``` @@ -174,7 +174,7 @@ The calibration process is very important because it allows a neural network tra Make sure the arm is connected to the Raspberry Pi and run this script or API example (on the Raspberry Pi via SSH) to launch calibration of the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=lekiwi \ --robot.id=my_awesome_kiwi # <- Give the robot a unique name ``` @@ -193,7 +193,7 @@ Then, to calibrate the leader arm (which is attached to the laptop/pc). Run the ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name @@ -258,7 +258,7 @@ You should see on your laptop something like this: `[INFO] Connected to remote r | F | Decrease speed | > [!TIP] -> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiConfig`](../src/lerobot/robot_devices/robots/configs.py). +> If you use a different keyboard, you can change the keys for each command in the [`LeKiwiClientConfig`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/robots/lekiwi/config_lekiwi.py). 
### Wired version diff --git a/docs/source/policy_act_README.md b/docs/source/policy_act_README.md new file mode 100644 index 000000000..371a9136f --- /dev/null +++ b/docs/source/policy_act_README.md @@ -0,0 +1,14 @@ +## Paper + +https://tonyzhaozh.github.io/aloha + +## Citation + +```bibtex +@article{zhao2023learning, + title={Learning fine-grained bimanual manipulation with low-cost hardware}, + author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea}, + journal={arXiv preprint arXiv:2304.13705}, + year={2023} +} +``` diff --git a/docs/source/policy_diffusion_README.md b/docs/source/policy_diffusion_README.md new file mode 100644 index 000000000..9ec934add --- /dev/null +++ b/docs/source/policy_diffusion_README.md @@ -0,0 +1,14 @@ +## Paper + +https://diffusion-policy.cs.columbia.edu + +## Citation + +```bibtex +@article{chi2024diffusionpolicy, + author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song}, + title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion}, + journal = {The International Journal of Robotics Research}, + year = {2024}, +} +``` diff --git a/docs/source/policy_smolvla_README.md b/docs/source/policy_smolvla_README.md new file mode 100644 index 000000000..ee567ee83 --- /dev/null +++ b/docs/source/policy_smolvla_README.md @@ -0,0 +1,14 @@ +## Paper + +https://arxiv.org/abs/2506.01844 + +## Citation + +```bibtex +@article{shukor2025smolvla, + title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics}, + author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi}, + journal={arXiv preprint arXiv:2506.01844}, + year={2025} +} +``` diff --git a/docs/source/policy_tdmpc_README.md 
b/docs/source/policy_tdmpc_README.md new file mode 100644 index 000000000..804f166c8 --- /dev/null +++ b/docs/source/policy_tdmpc_README.md @@ -0,0 +1,14 @@ +## Paper + +https://www.nicklashansen.com/td-mpc/ + +## Citation + +```bibtex +@inproceedings{Hansen2022tdmpc, + title={Temporal Difference Learning for Model Predictive Control}, + author={Nicklas Hansen and Xiaolong Wang and Hao Su}, + booktitle={ICML}, + year={2022} +} +``` diff --git a/docs/source/policy_vqbet_README.md b/docs/source/policy_vqbet_README.md new file mode 100644 index 000000000..02f95b7c2 --- /dev/null +++ b/docs/source/policy_vqbet_README.md @@ -0,0 +1,14 @@ +## Paper + +https://sjlee.cc/vq-bet/ + +## Citation + +```bibtex +@article{lee2024behavior, + title={Behavior generation with latent actions}, + author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel}, + journal={arXiv preprint arXiv:2403.03181}, + year={2024} +} +``` diff --git a/docs/source/smolvla.mdx b/docs/source/smolvla.mdx index 880beaa1a..89c475a90 100644 --- a/docs/source/smolvla.mdx +++ b/docs/source/smolvla.mdx @@ -54,7 +54,7 @@ If you don't have a gpu device, you can train using our notebook on [![Google Co Pass your dataset to the training script using `--dataset.repo_id`. If you want to test your installation, run the following command where we use one of the datasets we collected for the [SmolVLA Paper](https://huggingface.co/papers/2506.01844). ```bash -cd lerobot && python -m lerobot.scripts.train \ +cd lerobot && lerobot-train \ --policy.path=lerobot/smolvla_base \ --dataset.repo_id=${HF_USER}/mydataset \ --batch_size=64 \ @@ -73,7 +73,7 @@ cd lerobot && python -m lerobot.scripts.train \ Fine-tuning is an art. For a complete overview of the options for finetuning, run ```bash -python -m lerobot.scripts.train --help +lerobot-train --help ```

@@ -97,7 +97,7 @@ Similarly for when recording an episode, it is recommended that you are logged i Once you are logged in, you can run inference in your setup by doing: ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so101_follower \ --robot.port=/dev/ttyACM0 \ # <- Use your port --robot.id=my_blue_follower_arm \ # <- Use your robot id diff --git a/docs/source/so100.mdx b/docs/source/so100.mdx index d9ff922c5..8578e1e8d 100644 --- a/docs/source/so100.mdx +++ b/docs/source/so100.mdx @@ -26,7 +26,7 @@ Unlike the SO-101, the motor connectors are not easily accessible once the arm i To find the port for each bus servo adapter, run this script: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -93,7 +93,7 @@ For a visual reference on how to set the motor ids please refer to [this video]( ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step ``` @@ -168,7 +168,7 @@ Do the same steps for the leader arm. 
```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` @@ -568,7 +568,7 @@ Run the following command or API example to calibrate the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name @@ -606,7 +606,7 @@ Do the same steps to calibrate the leader arm, run the following command or API ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name diff --git a/docs/source/so101.mdx b/docs/source/so101.mdx index a20a3fa9f..b9fb9cab4 100644 --- a/docs/source/so101.mdx +++ b/docs/source/so101.mdx @@ -162,7 +162,7 @@ It is advisable to install one 3-pin cable in the motor after placing them befor To find the port for each bus servo adapter, connect MotorBus to your computer via USB and power. Run the following script and disconnect the MotorBus when prompted: ```bash -python -m lerobot.find_port +lerobot-find-port ``` @@ -240,7 +240,7 @@ Connect the usb cable from your computer and the power supply to the follower ar ```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem585A0076841 # <- paste here the port found at previous step ``` @@ -316,7 +316,7 @@ Do the same steps for the leader arm. 
```bash -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=so101_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 # <- paste here the port found at previous step ``` @@ -353,7 +353,7 @@ Run the following command or API example to calibrate the follower arm: ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --robot.id=my_awesome_follower_arm # <- Give the robot a unique name @@ -402,7 +402,7 @@ Do the same steps to calibrate the leader arm, run the following command or API ```bash -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so101_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ # <- The port of your robot --teleop.id=my_awesome_leader_arm # <- Give the robot a unique name diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index d6cd6cc23..ffa7de66e 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -62,7 +62,7 @@ By default, every field takes its default value specified in the dataclass. If a Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=lerobot/pusht \ --policy.type=diffusion \ --env.type=pusht @@ -77,7 +77,7 @@ Let's break this down: Let's see another example. 
Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ @@ -90,7 +90,7 @@ We now want to train a different policy for aloha on another task. We'll change Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ --env.type=aloha \ @@ -127,7 +127,7 @@ Now, let's assume that we want to reproduce the run just above. 
That run has pro We can then simply load the config values from this file using: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 ``` @@ -137,7 +137,7 @@ python -m lerobot.scripts.train \ Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 --policy.n_action_steps=80 @@ -148,7 +148,7 @@ python -m lerobot.scripts.train \ `--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running: ```bash -python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht +lerobot-train --config_path=lerobot/diffusion_pusht ``` will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht) @@ -160,7 +160,7 @@ Being able to resume a training run is important in case it crashed or aborted f Let's reuse the command from the previous run and add a few more options: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ --env.type=aloha \ @@ -179,7 +179,7 @@ INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100 Now let's simulate a crash by killing the process (hit `ctrl`+`c`). 
We can then simply resume this run from the last checkpoint available with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true ``` @@ -190,7 +190,7 @@ Another reason for which you might want to resume a run is simply to extend trai You could double the number of steps of the previous run with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true \ --steps=200000 @@ -224,7 +224,7 @@ In addition to the features currently in Draccus, we've added a special `.path` For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ @@ -270,7 +270,7 @@ We'll summarize here the main use cases to remember from this tutorial. 
#### Train a policy from scratch – CLI ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=act \ # <- select 'act' policy --env.type=pusht \ # <- select 'pusht' environment --dataset.repo_id=lerobot/pusht # <- train on this dataset @@ -279,7 +279,7 @@ python -m lerobot.scripts.train \ #### Train a policy from scratch - config file + CLI ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=path/to/pretrained_model \ # <- can also be a repo_id --policy.n_action_steps=80 # <- you may still override values ``` @@ -287,7 +287,7 @@ python -m lerobot.scripts.train \ #### Resume/continue a training run ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --config_path=checkpoint/pretrained_model/ \ --resume=true \ --steps=200000 # <- you can change some training parameters @@ -296,7 +296,7 @@ python -m lerobot.scripts.train \ #### Fine-tuning ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ diff --git a/examples/backward_compatibility/replay.py b/examples/backward_compatibility/replay.py index cc3397543..6c680f204 100644 --- a/examples/backward_compatibility/replay.py +++ b/examples/backward_compatibility/replay.py @@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot. 
Example: ```shell -python -m lerobot.replay \ +lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ diff --git a/pyproject.toml b/pyproject.toml index 9cf8b113a..d1eb3949c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb" [project] name = "lerobot" -version = "0.2.0" +version = "0.3.4" description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch" readme = "README.md" license = { text = "Apache-2.0" } @@ -61,21 +61,22 @@ dependencies = [ # Hugging Face dependencies "datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency "diffusers>=0.27.2", - "huggingface-hub[hf-transfer,cli]>=0.27.1", + "huggingface-hub[hf-transfer,cli]>=0.34.2", # Core dependencies "cmake>=3.29.0.1", "einops>=0.8.0", "opencv-python-headless>=4.9.0", "av>=14.2.0", - "torch>=2.2.1", - "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", - "torchvision>=0.21.0", "jsonlines>=4.0.0", "packaging>=24.2", "pynput>=1.7.7", "pyserial>=3.5", - "wandb>=0.16.3", + "wandb>=0.20.0", + + "torch>=2.2.1,<2.8.0", # TODO: Bumb dependency + "torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency + "torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency "draccus==0.10.0", # TODO: Remove == "gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency @@ -126,7 +127,6 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]", async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"] # Development -docs = ["hf-doc-builder @ 
git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"] test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] @@ -149,7 +149,6 @@ all = [ "lerobot[smolvla]", "lerobot[hilserl]", "lerobot[async]", - "lerobot[docs]", "lerobot[dev]", "lerobot[test]", "lerobot[video_benchmark]", diff --git a/requirements-macos.txt b/requirements-macos.txt new file mode 100644 index 000000000..07e263da5 --- /dev/null +++ b/requirements-macos.txt @@ -0,0 +1,625 @@ +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --output-file=requirements-macos.txt requirements.in +# +-e .[all] + # via -[all] +absl-py==2.3.1 + # via + # dm-control + # dm-env + # dm-tree + # labmaze + # mujoco +accelerate==1.9.0 + # via lerobot +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +asttokens==3.0.0 + # via stack-data +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via + # aiohttp + # dm-tree + # jsonlines + # rerun-sdk +av==15.0.0 + # via lerobot +blinker==1.9.0 + # via flask +certifi==2025.7.14 + # via + # requests + # sentry-sdk +cffi==1.17.1 + # via pymunk +cfgv==3.4.0 + # via pre-commit +charset-normalizer==3.4.2 + # via requests +click==8.2.1 + # via + # flask + # wandb +cloudpickle==3.1.1 + # via gymnasium +cmake==4.0.3 + # via lerobot +cmeel==0.57.3 + # via + # cmeel-assimp + # cmeel-boost + # cmeel-console-bridge + # cmeel-octomap + # cmeel-qhull + # cmeel-tinyxml2 + # cmeel-urdfdom + # cmeel-zlib + # coal-library + # eigenpy + # eiquadprog + # pin + # placo + # rhoban-cmeel-jsoncpp +cmeel-assimp==5.4.3.1 + # via coal-library +cmeel-boost==1.87.0.1 + # via + # coal-library + # eigenpy + # eiquadprog 
+ # pin +cmeel-console-bridge==1.0.2.3 + # via cmeel-urdfdom +cmeel-octomap==1.10.0 + # via coal-library +cmeel-qhull==8.0.2.1 + # via coal-library +cmeel-tinyxml2==10.0.0 + # via cmeel-urdfdom +cmeel-urdfdom==4.0.1 + # via pin +cmeel-zlib==1.3.1 + # via cmeel-assimp +coal-library==3.0.1 + # via pin +contourpy==1.3.2 + # via matplotlib +coverage[toml]==7.10.1 + # via pytest-cov +cycler==0.12.1 + # via matplotlib +datasets==3.6.0 + # via lerobot +debugpy==1.8.15 + # via lerobot +decorator==5.2.1 + # via ipython +deepdiff==8.5.0 + # via lerobot +diffusers==0.34.0 + # via lerobot +dill==0.3.8 + # via + # datasets + # multiprocess +distlib==0.4.0 + # via virtualenv +dm-control==1.0.14 + # via gym-aloha +dm-env==1.6 + # via dm-control +dm-tree==0.1.9 + # via + # dm-control + # dm-env +docopt==0.6.2 + # via num2words +draccus==0.10.0 + # via lerobot +dynamixel-sdk==3.7.31 + # via lerobot +eigenpy==3.10.3 + # via coal-library +einops==0.8.1 + # via lerobot +eiquadprog==1.2.9 + # via placo +exceptiongroup==1.3.0 + # via + # ipython + # pytest +executing==2.2.0 + # via stack-data +farama-notifications==0.0.4 + # via gymnasium +feetech-servo-sdk==1.0.0 + # via lerobot +filelock==3.18.0 + # via + # datasets + # diffusers + # huggingface-hub + # torch + # transformers + # virtualenv +flask==3.1.1 + # via lerobot +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2025.3.0 + # via + # datasets + # huggingface-hub + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via wandb +glfw==2.9.0 + # via + # dm-control + # mujoco +grpcio==1.73.1 + # via + # grpcio-tools + # lerobot +grpcio-tools==1.73.1 + # via lerobot +gym-aloha==0.1.1 + # via lerobot +gym-hil==0.1.10 + # via lerobot +gym-pusht==0.1.5 + # via lerobot +gym-xarm==0.1.1 + # via lerobot +gymnasium==0.29.1 + # via + # gym-aloha + # gym-hil + # gym-pusht + # gym-xarm + # gymnasium-robotics + # lerobot + # pettingzoo +gymnasium-robotics==1.2.4 + # via gym-xarm 
+hf-transfer==0.1.9 + # via huggingface-hub +hf-xet==1.1.5 + # via huggingface-hub +hidapi==0.14.0.post4 + # via + # gym-hil + # lerobot +huggingface-hub[cli,hf-transfer]==0.34.3 + # via + # accelerate + # datasets + # diffusers + # lerobot + # tokenizers + # transformers +identify==2.6.12 + # via pre-commit +idna==3.10 + # via + # requests + # yarl +imageio[ffmpeg]==2.37.0 + # via + # gym-aloha + # gym-hil + # gymnasium-robotics + # lerobot + # scikit-image +imageio-ffmpeg==0.6.0 + # via imageio +importlib-metadata==8.7.0 + # via diffusers +iniconfig==2.1.0 + # via pytest +inquirerpy==0.3.4 + # via huggingface-hub +ipython==8.37.0 + # via meshcat +ischedule==1.2.7 + # via placo +itsdangerous==2.2.0 + # via flask +jedi==0.19.2 + # via ipython +jinja2==3.1.6 + # via + # flask + # gymnasium-robotics + # torch +jsonlines==4.0.0 + # via lerobot +kiwisolver==1.4.8 + # via matplotlib +labmaze==1.0.6 + # via dm-control +lazy-loader==0.4 + # via scikit-image +lxml==6.0.0 + # via dm-control +markupsafe==3.0.2 + # via + # flask + # jinja2 + # werkzeug +matplotlib==3.10.5 + # via lerobot +matplotlib-inline==0.1.7 + # via ipython +mergedeep==1.3.4 + # via draccus +meshcat==0.3.2 + # via placo +mock-serial==0.0.1 + # via lerobot +mpmath==1.3.0 + # via sympy +mujoco==2.3.7 + # via + # dm-control + # gym-aloha + # gym-hil + # gym-xarm + # gymnasium-robotics +multidict==6.6.3 + # via + # aiohttp + # yarl +multiprocess==0.70.16 + # via datasets +mypy-extensions==1.1.0 + # via typing-inspect +networkx==3.4.2 + # via + # scikit-image + # torch +nodeenv==1.9.1 + # via pre-commit +num2words==0.5.14 + # via lerobot +numpy==2.2.6 + # via + # accelerate + # cmeel-boost + # contourpy + # datasets + # diffusers + # dm-control + # dm-env + # dm-tree + # gymnasium + # gymnasium-robotics + # imageio + # labmaze + # matplotlib + # meshcat + # mujoco + # opencv-python + # opencv-python-headless + # pandas + # pettingzoo + # rerun-sdk + # scikit-image + # scipy + # shapely + # tifffile + # 
torchvision + # transformers +opencv-python==4.12.0.88 + # via gym-pusht +opencv-python-headless==4.12.0.88 + # via lerobot +orderly-set==5.5.0 + # via deepdiff +packaging==25.0 + # via + # accelerate + # datasets + # huggingface-hub + # lazy-loader + # lerobot + # matplotlib + # pytest + # scikit-image + # transformers + # wandb +pandas==2.3.1 + # via + # datasets + # lerobot +parso==0.8.4 + # via jedi +pettingzoo==1.24.3 + # via gymnasium-robotics +pexpect==4.9.0 + # via ipython +pfzy==0.3.4 + # via inquirerpy +pillow==11.3.0 + # via + # diffusers + # imageio + # matplotlib + # meshcat + # rerun-sdk + # scikit-image + # torchvision +pin==3.4.0 + # via placo +placo==0.9.14 + # via lerobot +platformdirs==4.3.8 + # via + # virtualenv + # wandb +pluggy==1.6.0 + # via + # pytest + # pytest-cov +pre-commit==4.2.0 + # via lerobot +prompt-toolkit==3.0.51 + # via + # inquirerpy + # ipython +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.0 + # via + # dm-control + # grpcio-tools + # lerobot + # wandb +psutil==7.0.0 + # via + # accelerate + # imageio +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data +pyarrow==21.0.0 + # via + # datasets + # rerun-sdk +pycparser==2.22 + # via cffi +pydantic==2.11.7 + # via wandb +pydantic-core==2.33.2 + # via pydantic +pygame==2.6.1 + # via + # gym-hil + # gym-pusht + # lerobot +pygments==2.19.2 + # via + # ipython + # pytest +pymunk==6.11.1 + # via + # gym-pusht + # lerobot +pyngrok==7.2.12 + # via meshcat +pynput==1.8.1 + # via + # gym-hil + # lerobot +pyobjc-core==11.1 + # via + # pyobjc-framework-applicationservices + # pyobjc-framework-cocoa + # pyobjc-framework-coretext + # pyobjc-framework-quartz +pyobjc-framework-applicationservices==11.1 + # via pynput +pyobjc-framework-cocoa==11.1 + # via + # pyobjc-framework-applicationservices + # pyobjc-framework-coretext + # pyobjc-framework-quartz +pyobjc-framework-coretext==11.1 + # via pyobjc-framework-applicationservices +pyobjc-framework-quartz==11.1 + 
# via + # pynput + # pyobjc-framework-applicationservices + # pyobjc-framework-coretext +pyopengl==3.1.9 + # via + # dm-control + # mujoco +pyparsing==3.2.3 + # via + # dm-control + # matplotlib +pyrealsense2-macosx==2.54.2 + # via lerobot +pyserial==3.5 + # via + # dynamixel-sdk + # feetech-servo-sdk + # lerobot +pytest==8.4.1 + # via + # lerobot + # pytest-cov + # pytest-timeout +pytest-cov==6.2.1 + # via lerobot +pytest-timeout==2.4.0 + # via lerobot +python-dateutil==2.9.0.post0 + # via + # matplotlib + # pandas +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # accelerate + # datasets + # draccus + # huggingface-hub + # pre-commit + # pyngrok + # pyyaml-include + # transformers + # wandb +pyyaml-include==1.4.1 + # via draccus +pyzmq==27.0.0 + # via + # lerobot + # meshcat +regex==2025.7.34 + # via + # diffusers + # transformers +requests==2.32.4 + # via + # datasets + # diffusers + # dm-control + # huggingface-hub + # transformers + # wandb +rerun-sdk==0.22.1 + # via lerobot +rhoban-cmeel-jsoncpp==1.9.4.9 + # via placo +safetensors==0.5.3 + # via + # accelerate + # diffusers + # lerobot + # transformers +scikit-image==0.25.2 + # via + # gym-pusht + # lerobot +scipy==1.15.3 + # via + # dm-control + # scikit-image +sentry-sdk==2.34.1 + # via wandb +shapely==2.1.1 + # via gym-pusht +six==1.17.0 + # via + # pynput + # python-dateutil +smmap==5.0.2 + # via gitdb +stack-data==0.6.3 + # via ipython +sympy==1.14.0 + # via torch +termcolor==3.1.0 + # via lerobot +tifffile==2025.5.10 + # via scikit-image +tokenizers==0.21.4 + # via transformers +toml==0.10.2 + # via draccus +tomli==2.2.1 + # via + # cmeel + # coverage + # pytest +torch==2.7.1 + # via + # accelerate + # lerobot + # torchvision +torchcodec==0.5 + # via lerobot +torchvision==0.22.1 + # via lerobot +tornado==6.5.1 + # via meshcat +tqdm==4.67.1 + # via + # datasets + # dm-control + # huggingface-hub + # transformers +traitlets==5.14.3 + # via + # ipython + # matplotlib-inline +transformers==4.51.3 + # 
via lerobot +typing-extensions==4.14.1 + # via + # aiosignal + # exceptiongroup + # gymnasium + # huggingface-hub + # ipython + # multidict + # pydantic + # pydantic-core + # rerun-sdk + # torch + # typing-inspect + # typing-inspection + # wandb +typing-inspect==0.9.0 + # via draccus +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +u-msgpack-python==2.8.0 + # via meshcat +urllib3==2.5.0 + # via + # requests + # sentry-sdk +virtualenv==20.32.0 + # via pre-commit +wandb==0.21.0 + # via lerobot +wcwidth==0.2.13 + # via prompt-toolkit +werkzeug==3.1.3 + # via flask +wrapt==1.17.2 + # via dm-tree +xxhash==3.5.0 + # via datasets +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements-ubuntu.txt b/requirements-ubuntu.txt new file mode 100644 index 000000000..af7258d67 --- /dev/null +++ b/requirements-ubuntu.txt @@ -0,0 +1,650 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --output-file=requirements-ubuntu.txt requirements.in +# +-e .[all] + # via -[all] +absl-py==2.3.1 + # via + # dm-control + # dm-env + # dm-tree + # labmaze + # mujoco +accelerate==1.9.0 + # via lerobot +aiohappyeyeballs==2.6.1 + # via aiohttp +aiohttp==3.12.15 + # via fsspec +aiosignal==1.4.0 + # via aiohttp +annotated-types==0.7.0 + # via pydantic +asttokens==3.0.0 + # via stack-data +async-timeout==5.0.1 + # via aiohttp +attrs==25.3.0 + # via + # aiohttp + # dm-tree + # jsonlines + # rerun-sdk +av==15.0.0 + # via lerobot +blinker==1.9.0 + # via flask +certifi==2025.7.14 + # via + # requests + # sentry-sdk +cffi==1.17.1 + # via pymunk +cfgv==3.4.0 + # via pre-commit +charset-normalizer==3.4.2 + # via requests +click==8.2.1 + # via + # flask + # wandb +cloudpickle==3.1.1 + # via gymnasium +cmake==4.0.3 + # via lerobot +cmeel==0.57.3 + # via + # cmeel-assimp + # cmeel-boost + # 
cmeel-console-bridge + # cmeel-octomap + # cmeel-qhull + # cmeel-tinyxml2 + # cmeel-urdfdom + # cmeel-zlib + # coal-library + # eigenpy + # eiquadprog + # pin + # placo + # rhoban-cmeel-jsoncpp +cmeel-assimp==5.4.3.1 + # via coal-library +cmeel-boost==1.87.0.1 + # via + # coal-library + # eigenpy + # eiquadprog + # pin +cmeel-console-bridge==1.0.2.3 + # via cmeel-urdfdom +cmeel-octomap==1.10.0 + # via coal-library +cmeel-qhull==8.0.2.1 + # via coal-library +cmeel-tinyxml2==10.0.0 + # via cmeel-urdfdom +cmeel-urdfdom==4.0.1 + # via pin +cmeel-zlib==1.3.1 + # via cmeel-assimp +coal-library==3.0.1 + # via pin +contourpy==1.3.2 + # via matplotlib +coverage[toml]==7.10.1 + # via pytest-cov +cycler==0.12.1 + # via matplotlib +datasets==3.6.0 + # via lerobot +debugpy==1.8.15 + # via lerobot +decorator==5.2.1 + # via ipython +deepdiff==8.5.0 + # via lerobot +diffusers==0.34.0 + # via lerobot +dill==0.3.8 + # via + # datasets + # multiprocess +distlib==0.4.0 + # via virtualenv +dm-control==1.0.14 + # via gym-aloha +dm-env==1.6 + # via dm-control +dm-tree==0.1.9 + # via + # dm-control + # dm-env +docopt==0.6.2 + # via num2words +draccus==0.10.0 + # via lerobot +dynamixel-sdk==3.7.31 + # via lerobot +eigenpy==3.10.3 + # via coal-library +einops==0.8.1 + # via lerobot +eiquadprog==1.2.9 + # via placo +evdev==1.9.2 + # via pynput +exceptiongroup==1.3.0 + # via + # ipython + # pytest +executing==2.2.0 + # via stack-data +farama-notifications==0.0.4 + # via gymnasium +feetech-servo-sdk==1.0.0 + # via lerobot +filelock==3.18.0 + # via + # datasets + # diffusers + # huggingface-hub + # torch + # transformers + # virtualenv +flask==3.1.1 + # via lerobot +fonttools==4.59.0 + # via matplotlib +frozenlist==1.7.0 + # via + # aiohttp + # aiosignal +fsspec[http]==2025.3.0 + # via + # datasets + # huggingface-hub + # torch +gitdb==4.0.12 + # via gitpython +gitpython==3.1.45 + # via wandb +glfw==2.9.0 + # via + # dm-control + # mujoco +grpcio==1.73.1 + # via + # grpcio-tools + # lerobot 
+grpcio-tools==1.73.1 + # via lerobot +gym-aloha==0.1.1 + # via lerobot +gym-hil==0.1.10 + # via lerobot +gym-pusht==0.1.5 + # via lerobot +gym-xarm==0.1.1 + # via lerobot +gymnasium==0.29.1 + # via + # gym-aloha + # gym-hil + # gym-pusht + # gym-xarm + # gymnasium-robotics + # lerobot + # pettingzoo +gymnasium-robotics==1.2.4 + # via gym-xarm +hf-transfer==0.1.9 + # via huggingface-hub +hf-xet==1.1.5 + # via huggingface-hub +hidapi==0.14.0.post4 + # via + # gym-hil + # lerobot +huggingface-hub[cli,hf-transfer]==0.34.3 + # via + # accelerate + # datasets + # diffusers + # lerobot + # tokenizers + # transformers +identify==2.6.12 + # via pre-commit +idna==3.10 + # via + # requests + # yarl +imageio[ffmpeg]==2.37.0 + # via + # gym-aloha + # gym-hil + # gymnasium-robotics + # lerobot + # scikit-image +imageio-ffmpeg==0.6.0 + # via imageio +importlib-metadata==8.7.0 + # via diffusers +iniconfig==2.1.0 + # via pytest +inquirerpy==0.3.4 + # via huggingface-hub +ipython==8.37.0 + # via meshcat +ischedule==1.2.7 + # via placo +itsdangerous==2.2.0 + # via flask +jedi==0.19.2 + # via ipython +jinja2==3.1.6 + # via + # flask + # gymnasium-robotics + # torch +jsonlines==4.0.0 + # via lerobot +kiwisolver==1.4.8 + # via matplotlib +labmaze==1.0.6 + # via dm-control +lazy-loader==0.4 + # via scikit-image +lxml==6.0.0 + # via dm-control +markupsafe==3.0.2 + # via + # flask + # jinja2 + # werkzeug +matplotlib==3.10.5 + # via lerobot +matplotlib-inline==0.1.7 + # via ipython +mergedeep==1.3.4 + # via draccus +meshcat==0.3.2 + # via placo +mock-serial==0.0.1 + # via lerobot +mpmath==1.3.0 + # via sympy +mujoco==2.3.7 + # via + # dm-control + # gym-aloha + # gym-hil + # gym-xarm + # gymnasium-robotics +multidict==6.6.3 + # via + # aiohttp + # yarl +multiprocess==0.70.16 + # via datasets +mypy-extensions==1.1.0 + # via typing-inspect +networkx==3.4.2 + # via + # scikit-image + # torch +nodeenv==1.9.1 + # via pre-commit +num2words==0.5.14 + # via lerobot +numpy==2.2.6 + # via + # 
accelerate + # cmeel-boost + # contourpy + # datasets + # diffusers + # dm-control + # dm-env + # dm-tree + # gymnasium + # gymnasium-robotics + # imageio + # labmaze + # matplotlib + # meshcat + # mujoco + # opencv-python + # opencv-python-headless + # pandas + # pettingzoo + # rerun-sdk + # scikit-image + # scipy + # shapely + # tifffile + # torchvision + # transformers +nvidia-cublas-cu12==12.6.4.1 + # via + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 + # torch +nvidia-cuda-cupti-cu12==12.6.80 + # via torch +nvidia-cuda-nvrtc-cu12==12.6.77 + # via torch +nvidia-cuda-runtime-cu12==12.6.77 + # via torch +nvidia-cudnn-cu12==9.5.1.17 + # via torch +nvidia-cufft-cu12==11.3.0.4 + # via torch +nvidia-cufile-cu12==1.11.1.6 + # via torch +nvidia-curand-cu12==10.3.7.77 + # via torch +nvidia-cusolver-cu12==11.7.1.2 + # via torch +nvidia-cusparse-cu12==12.5.4.2 + # via + # nvidia-cusolver-cu12 + # torch +nvidia-cusparselt-cu12==0.6.3 + # via torch +nvidia-nccl-cu12==2.26.2 + # via torch +nvidia-nvjitlink-cu12==12.6.85 + # via + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 + # torch +nvidia-nvtx-cu12==12.6.77 + # via torch +opencv-python==4.12.0.88 + # via gym-pusht +opencv-python-headless==4.12.0.88 + # via lerobot +orderly-set==5.5.0 + # via deepdiff +packaging==25.0 + # via + # accelerate + # datasets + # huggingface-hub + # lazy-loader + # lerobot + # matplotlib + # pytest + # scikit-image + # transformers + # wandb +pandas==2.3.1 + # via + # datasets + # lerobot +parso==0.8.4 + # via jedi +pettingzoo==1.24.3 + # via gymnasium-robotics +pexpect==4.9.0 + # via ipython +pfzy==0.3.4 + # via inquirerpy +pillow==11.3.0 + # via + # diffusers + # imageio + # matplotlib + # meshcat + # rerun-sdk + # scikit-image + # torchvision +pin==3.4.0 + # via placo +placo==0.9.14 + # via lerobot +platformdirs==4.3.8 + # via + # virtualenv + # wandb +pluggy==1.6.0 + # via + # pytest + # pytest-cov +pre-commit==4.2.0 + # via lerobot +prompt-toolkit==3.0.51 + # via + 
# inquirerpy + # ipython +propcache==0.3.2 + # via + # aiohttp + # yarl +protobuf==6.31.0 + # via + # dm-control + # grpcio-tools + # lerobot + # wandb +psutil==7.0.0 + # via + # accelerate + # imageio +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.3 + # via stack-data +pyarrow==21.0.0 + # via + # datasets + # rerun-sdk +pycparser==2.22 + # via cffi +pydantic==2.11.7 + # via wandb +pydantic-core==2.33.2 + # via pydantic +pygame==2.6.1 + # via + # gym-hil + # gym-pusht + # lerobot +pygments==2.19.2 + # via + # ipython + # pytest +pymunk==6.11.1 + # via + # gym-pusht + # lerobot +pyngrok==7.2.12 + # via meshcat +pynput==1.8.1 + # via + # gym-hil + # lerobot +pyopengl==3.1.9 + # via + # dm-control + # mujoco +pyparsing==3.2.3 + # via + # dm-control + # matplotlib +pyrealsense2==2.56.5.9235 + # via lerobot +pyserial==3.5 + # via + # dynamixel-sdk + # feetech-servo-sdk + # lerobot +pytest==8.4.1 + # via + # lerobot + # pytest-cov + # pytest-timeout +pytest-cov==6.2.1 + # via lerobot +pytest-timeout==2.4.0 + # via lerobot +python-dateutil==2.9.0.post0 + # via + # matplotlib + # pandas +python-xlib==0.33 + # via pynput +pytz==2025.2 + # via pandas +pyyaml==6.0.2 + # via + # accelerate + # datasets + # draccus + # huggingface-hub + # pre-commit + # pyngrok + # pyyaml-include + # transformers + # wandb +pyyaml-include==1.4.1 + # via draccus +pyzmq==27.0.0 + # via + # lerobot + # meshcat +regex==2025.7.34 + # via + # diffusers + # transformers +requests==2.32.4 + # via + # datasets + # diffusers + # dm-control + # huggingface-hub + # transformers + # wandb +rerun-sdk==0.22.1 + # via lerobot +rhoban-cmeel-jsoncpp==1.9.4.9 + # via placo +safetensors==0.5.3 + # via + # accelerate + # diffusers + # lerobot + # transformers +scikit-image==0.25.2 + # via + # gym-pusht + # lerobot +scipy==1.15.3 + # via + # dm-control + # scikit-image +sentry-sdk==2.34.1 + # via wandb +shapely==2.1.1 + # via gym-pusht +six==1.17.0 + # via + # pynput + # python-dateutil + # python-xlib 
+smmap==5.0.2 + # via gitdb +stack-data==0.6.3 + # via ipython +sympy==1.14.0 + # via torch +termcolor==3.1.0 + # via lerobot +tifffile==2025.5.10 + # via scikit-image +tokenizers==0.21.4 + # via transformers +toml==0.10.2 + # via draccus +tomli==2.2.1 + # via + # cmeel + # coverage + # pytest +torch==2.7.1 + # via + # accelerate + # lerobot + # torchvision +torchcodec==0.5 + # via lerobot +torchvision==0.22.1 + # via lerobot +tornado==6.5.1 + # via meshcat +tqdm==4.67.1 + # via + # datasets + # dm-control + # huggingface-hub + # transformers +traitlets==5.14.3 + # via + # ipython + # matplotlib-inline +transformers==4.51.3 + # via lerobot +triton==3.3.1 + # via torch +typing-extensions==4.14.1 + # via + # aiosignal + # exceptiongroup + # gymnasium + # huggingface-hub + # ipython + # multidict + # pydantic + # pydantic-core + # rerun-sdk + # torch + # typing-inspect + # typing-inspection + # wandb +typing-inspect==0.9.0 + # via draccus +typing-inspection==0.4.1 + # via pydantic +tzdata==2025.2 + # via pandas +u-msgpack-python==2.8.0 + # via meshcat +urllib3==2.5.0 + # via + # requests + # sentry-sdk +virtualenv==20.32.0 + # via pre-commit +wandb==0.21.0 + # via lerobot +wcwidth==0.2.13 + # via prompt-toolkit +werkzeug==3.1.3 + # via flask +wrapt==1.17.2 + # via dm-tree +xxhash==3.5.0 + # via datasets +yarl==1.20.1 + # via aiohttp +zipp==3.23.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/requirements.in b/requirements.in new file mode 100644 index 000000000..272f7f540 --- /dev/null +++ b/requirements.in @@ -0,0 +1,9 @@ +# requirements.in + +# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64). 
+# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64 + +# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64). +# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux + +-e .[all] diff --git a/src/lerobot/__init__.py b/src/lerobot/__init__.py index 38d4e8644..9d3ed1893 100644 --- a/src/lerobot/__init__.py +++ b/src/lerobot/__init__.py @@ -170,7 +170,7 @@ available_datasets = sorted( # lists all available policies from `lerobot/policies` available_policies = ["act", "diffusion", "tdmpc", "vqbet"] -# lists all available robots from `lerobot/robot_devices/robots` +# lists all available robots from `lerobot/robots` available_robots = [ "koch", "koch_bimanual", @@ -179,13 +179,13 @@ available_robots = [ "so101", ] -# lists all available cameras from `lerobot/robot_devices/cameras` +# lists all available cameras from `lerobot/cameras` available_cameras = [ "opencv", "intelrealsense", ] -# lists all available motors from `lerobot/robot_devices/motors` +# lists all available motors from `lerobot/motors` available_motors = [ "dynamixel", "feetech", diff --git a/src/lerobot/calibrate.py b/src/lerobot/calibrate.py index 1e8bf4751..0aa61a2f9 100644 --- a/src/lerobot/calibrate.py +++ b/src/lerobot/calibrate.py @@ -18,7 +18,7 @@ Helper to recalibrate your device (robot or teleoperator). 
Example: ```shell -python -m lerobot.calibrate \ +lerobot-calibrate \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem58760431551 \ --teleop.id=blue @@ -82,5 +82,9 @@ def calibrate(cfg: CalibrateConfig): device.disconnect() -if __name__ == "__main__": +def main(): calibrate() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/cameras/opencv/camera_opencv.py b/src/lerobot/cameras/opencv/camera_opencv.py index 7ad9988cc..3665a909f 100644 --- a/src/lerobot/cameras/opencv/camera_opencv.py +++ b/src/lerobot/cameras/opencv/camera_opencv.py @@ -60,7 +60,7 @@ class OpenCVCamera(Camera): or port changes, especially on Linux. Use the provided utility script to find available camera indices or paths: ```bash - python -m lerobot.find_cameras opencv + lerobot-find-cameras opencv ``` The camera's default settings (FPS, resolution, color mode) are used unless @@ -165,8 +165,7 @@ class OpenCVCamera(Camera): self.videocapture.release() self.videocapture = None raise ConnectionError( - f"Failed to open {self}." - f"Run `python -m lerobot.find_cameras opencv` to find available cameras." + f"Failed to open {self}.Run `lerobot-find-cameras opencv` to find available cameras." 
) self._configure_capture_settings() @@ -368,7 +367,7 @@ class OpenCVCamera(Camera): if requested_color_mode == ColorMode.RGB: processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: processed_image = cv2.rotate(processed_image, self.rotation) return processed_image diff --git a/src/lerobot/cameras/realsense/camera_realsense.py b/src/lerobot/cameras/realsense/camera_realsense.py index 74b055fa4..12ce89c91 100644 --- a/src/lerobot/cameras/realsense/camera_realsense.py +++ b/src/lerobot/cameras/realsense/camera_realsense.py @@ -51,7 +51,7 @@ class RealSenseCamera(Camera): Use the provided utility script to find available camera indices and default profiles: ```bash - python -m lerobot.find_cameras realsense + lerobot-find-cameras realsense ``` A `RealSenseCamera` instance requires a configuration object specifying the @@ -176,8 +176,7 @@ class RealSenseCamera(Camera): self.rs_profile = None self.rs_pipeline = None raise ConnectionError( - f"Failed to open {self}." - "Run `python -m lerobot.find_cameras realsense` to find available cameras." + f"Failed to open {self}.Run `lerobot-find-cameras realsense` to find available cameras." 
) from e self._configure_capture_settings() @@ -434,7 +433,7 @@ class RealSenseCamera(Camera): if self.color_mode == ColorMode.BGR: processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]: + if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]: processed_image = cv2.rotate(processed_image, self.rotation) return processed_image diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index c5b2fa09e..f5fa727cf 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig from lerobot.utils.hub import HubMixin @@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def robot_state_feature(self) -> PolicyFeature | None: - for _, ft in self.input_features.items(): - if ft.type is FeatureType.STATE: + for ft_name, ft in self.input_features.items(): + if ft.type is FeatureType.STATE and ft_name == OBS_STATE: return ft return None @@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): @property def action_feature(self) -> PolicyFeature | None: - for _, ft in self.output_features.items(): - if ft.type is FeatureType.ACTION: + for ft_name, ft in self.output_features.items(): + if ft.type is FeatureType.ACTION and ft_name == ACTION: return ft return None diff --git a/src/lerobot/envs/configs.py b/src/lerobot/envs/configs.py index ef381e9e7..35797c6ed 100644 --- a/src/lerobot/envs/configs.py +++ b/src/lerobot/envs/configs.py @@ -44,7 +44,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC): 
@EnvConfig.register_subclass("aloha") @dataclass class AlohaEnv(EnvConfig): - task: str = "AlohaInsertion-v0" + task: str | None = "AlohaInsertion-v0" fps: int = 50 episode_length: int = 400 obs_type: str = "pixels_agent_pos" @@ -82,7 +82,7 @@ class AlohaEnv(EnvConfig): @EnvConfig.register_subclass("pusht") @dataclass class PushtEnv(EnvConfig): - task: str = "PushT-v0" + task: str | None = "PushT-v0" fps: int = 10 episode_length: int = 300 obs_type: str = "pixels_agent_pos" @@ -124,7 +124,7 @@ class PushtEnv(EnvConfig): @EnvConfig.register_subclass("xarm") @dataclass class XarmEnv(EnvConfig): - task: str = "XarmLift-v0" + task: str | None = "XarmLift-v0" fps: int = 15 episode_length: int = 200 obs_type: str = "pixels_agent_pos" @@ -200,10 +200,10 @@ class HILSerlRobotEnvConfig(EnvConfig): wrapper: EnvTransformConfig | None = None fps: int = 10 name: str = "real_robot" - mode: str = None # Either "record", "replay", None + mode: str | None = None # Either "record", "replay", None repo_id: str | None = None dataset_root: str | None = None - task: str = "" + task: str | None = "" num_episodes: int = 10 # only for record mode episode: int = 0 device: str = "cuda" @@ -213,6 +213,7 @@ class HILSerlRobotEnvConfig(EnvConfig): # For the reward classifier, to record more positive examples after a success number_of_steps_after_success: int = 0 + @property def gym_kwargs(self) -> dict: return {} @@ -222,9 +223,8 @@ class HILSerlRobotEnvConfig(EnvConfig): class HILEnvConfig(EnvConfig): """Configuration for the HIL environment.""" - type: str = "hil" name: str = "PandaPickCube" - task: str = "PandaPickCubeKeyboard-v0" + task: str | None = "PandaPickCubeKeyboard-v0" use_viewer: bool = True gripper_penalty: float = 0.0 use_gamepad: bool = True @@ -252,7 +252,7 @@ class HILEnvConfig(EnvConfig): robot_config: RobotConfig | None = None teleop_config: TeleoperatorConfig | None = None wrapper: EnvTransformConfig | None = None - mode: str = None # Either "record", "replay", None + mode: 
str | None = None # Either "record", "replay", None repo_id: str | None = None dataset_root: str | None = None num_episodes: int = 10 # only for record mode diff --git a/src/lerobot/find_cameras.py b/src/lerobot/find_cameras.py index be8f272ee..ec8f5ff30 100644 --- a/src/lerobot/find_cameras.py +++ b/src/lerobot/find_cameras.py @@ -20,7 +20,7 @@ Helper to find the camera devices available in your system. Example: ```shell -python -m lerobot.find_cameras +lerobot-find-cameras ``` """ @@ -286,7 +286,7 @@ def save_images_from_all_cameras( print(f"Image capture finished. Images saved to {output_dir}") -if __name__ == "__main__": +def main(): parser = argparse.ArgumentParser( description="Unified camera utility script for listing cameras and capturing images." ) @@ -313,3 +313,7 @@ if __name__ == "__main__": ) args = parser.parse_args() save_images_from_all_cameras(**vars(args)) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/find_port.py b/src/lerobot/find_port.py index cf0282507..e32b9cb99 100644 --- a/src/lerobot/find_port.py +++ b/src/lerobot/find_port.py @@ -18,7 +18,7 @@ Helper to find the USB port associated with your MotorsBus. Example: ```shell -python -m lerobot.find_port +lerobot-find-port ``` """ @@ -61,5 +61,9 @@ def find_port(): raise OSError(f"Could not detect the port. 
More than one port was found ({ports_diff}).") -if __name__ == "__main__": +def main(): find_port() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/motors/dynamixel/tables.py b/src/lerobot/motors/dynamixel/tables.py index 8b67bbf38..5417d8cee 100644 --- a/src/lerobot/motors/dynamixel/tables.py +++ b/src/lerobot/motors/dynamixel/tables.py @@ -107,6 +107,8 @@ X_SERIES_ENCODINGS_TABLE = { "Goal_PWM": X_SERIES_CONTROL_TABLE["Goal_PWM"][1], "Goal_Current": X_SERIES_CONTROL_TABLE["Goal_Current"][1], "Goal_Velocity": X_SERIES_CONTROL_TABLE["Goal_Velocity"][1], + "Goal_Position": X_SERIES_CONTROL_TABLE["Goal_Position"][1], + "Present_Position": X_SERIES_CONTROL_TABLE["Present_Position"][1], "Present_PWM": X_SERIES_CONTROL_TABLE["Present_PWM"][1], "Present_Current": X_SERIES_CONTROL_TABLE["Present_Current"][1], "Present_Velocity": X_SERIES_CONTROL_TABLE["Present_Velocity"][1], diff --git a/src/lerobot/motors/motors_bus.py b/src/lerobot/motors/motors_bus.py index 597bcd3c4..97830fc35 100644 --- a/src/lerobot/motors/motors_bus.py +++ b/src/lerobot/motors/motors_bus.py @@ -222,7 +222,7 @@ class MotorsBus(abc.ABC): A MotorsBus subclass instance requires a port (e.g. `FeetechMotorsBus(port="/dev/tty.usbmodem575E0031751"`)). To find the port, you can run our utility script: ```bash - python -m lerobot.find_port.py + lerobot-find-port.py >>> Finding all available ports for the MotorsBus. >>> ["/dev/tty.usbmodem575E0032081", "/dev/tty.usbmodem575E0031751"] >>> Remove the usb cable from your MotorsBus and press Enter when done. @@ -446,7 +446,7 @@ class MotorsBus(abc.ABC): except (FileNotFoundError, OSError, serial.SerialException) as e: raise ConnectionError( f"\nCould not connect on port '{self.port}'. Make sure you are using the correct port." 
- "\nTry running `python -m lerobot.find_port`\n" + "\nTry running `lerobot-find-port`\n" ) from e @abc.abstractmethod diff --git a/src/lerobot/policies/act/README.md b/src/lerobot/policies/act/README.md new file mode 120000 index 000000000..046020098 --- /dev/null +++ b/src/lerobot/policies/act/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_act_README.md \ No newline at end of file diff --git a/src/lerobot/policies/diffusion/README.md b/src/lerobot/policies/diffusion/README.md new file mode 120000 index 000000000..d332d79c8 --- /dev/null +++ b/src/lerobot/policies/diffusion/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_diffusion_README.md \ No newline at end of file diff --git a/src/lerobot/policies/diffusion/configuration_diffusion.py b/src/lerobot/policies/diffusion/configuration_diffusion.py index ce2de7052..54569434a 100644 --- a/src/lerobot/policies/diffusion/configuration_diffusion.py +++ b/src/lerobot/policies/diffusion/configuration_diffusion.py @@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig): ) # Check that all input images have the same shape. - first_image_key, first_image_ft = next(iter(self.image_features.items())) - for key, image_ft in self.image_features.items(): - if image_ft.shape != first_image_ft.shape: - raise ValueError( - f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." - ) + if len(self.image_features) > 0: + first_image_key, first_image_ft = next(iter(self.image_features.items())) + for key, image_ft in self.image_features.items(): + if image_ft.shape != first_image_ft.shape: + raise ValueError( + f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match." 
+ ) @property def observation_delta_indices(self) -> list: diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 24b273967..85d4d5981 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -133,11 +133,15 @@ class DiffusionPolicy(PreTrainedPolicy): "horizon" may not the best name to describe what the variable actually means, because this period is actually measured from the first observation which (if `n_obs_steps` > 1) happened in the past. """ + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) + batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - # Note: It's important that this happens after stacking the images into a single key. + # NOTE: It's important that this happens after stacking the images into a single key. 
self._queues = populate_queues(self._queues, batch) if len(self._queues[ACTION]) == 0: @@ -284,7 +288,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) } """ batch_size, n_obs_steps = batch["observation.state"].shape[:2] @@ -311,7 +315,7 @@ class DiffusionModel(nn.Module): "observation.images": (B, n_obs_steps, num_cameras, C, H, W) AND/OR - "observation.environment_state": (B, environment_dim) + "observation.environment_state": (B, n_obs_steps, environment_dim) "action": (B, horizon, action_dim) "action_is_pad": (B, horizon) diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 11feca964..de41e2bd4 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -30,7 +30,7 @@ pip install -e ".[pi0]" Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`): ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/pi0 \ --dataset.repo_id=danaaubakirova/koch_test ``` @@ -38,7 +38,7 @@ python -m lerobot.scripts.train \ Example of finetuning the pi0 neural network with PaliGemma and expert Gemma pretrained with VLM default parameters before pi0 finetuning: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=pi0 \ --dataset.repo_id=danaaubakirova/koch_test ``` @@ -66,7 +66,8 @@ from lerobot.policies.pi0.paligemma_with_expert import ( PaliGemmaWithExpertModel, ) from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.utils import get_safe_dtype +from lerobot.policies.utils import log_model_loading_keys +from lerobot.utils.utils import get_safe_dtype, init_logging def create_sinusoidal_pos_embedding( @@ -90,12 +91,6 @@ def create_sinusoidal_pos_embedding( return pos_emb -def sample_beta(alpha, beta, bsize, device): - gamma1 = 
torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - def make_att_2d_masks(pad_masks, att_masks): """Copied from big_vision. @@ -258,6 +253,99 @@ class PI0Policy(PreTrainedPolicy): """This should be called whenever the environment is reset.""" self._action_queue = deque([], maxlen=self.config.n_action_steps) + @classmethod + def _transform_state_dict_keys(cls, state_dict: dict) -> dict: + """ + Transform state dict keys to match expected model structure. + + Transformations: + - model.paligemma_with_expert.paligemma.language_model.lm_head -> + model.paligemma_with_expert.paligemma.lm_head + - model.paligemma_with_expert.paligemma.language_model.model -> + model.paligemma_with_expert.paligemma.model.language_model + - model.paligemma_with_expert.paligemma.vision_tower -> + model.paligemma_with_expert.paligemma.model.vision_tower + - model.paligemma_with_expert.paligemma.multi_modal_projector -> + model.paligemma_with_expert.paligemma.model.multi_modal_projector + + Also handles tied weights between lm_head.weight and + embed_tokens.weight. 
+ """ + import re + + transformed_dict = {} + + transformations = [ + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"), + ".paligemma_with_expert.paligemma.lm_head", + ), + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"), + ".paligemma_with_expert.paligemma.model.language_model", + ), + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"), + ".paligemma_with_expert.paligemma.model.vision_tower", + ), + ( + re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"), + ".paligemma_with_expert.paligemma.model.multi_modal_projector", + ), + ] + + for key, value in state_dict.items(): + new_key = key + for pattern, replacement in transformations: + new_key = pattern.sub(replacement, new_key) + transformed_dict[new_key] = value + + # Handle tied weights: lm_head.weight and embed_tokens.weight share memory + lm_head_key = None + embed_tokens_key = None + + for key in transformed_dict: + if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"): + lm_head_key = key + elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"): + embed_tokens_key = key + if lm_head_key and embed_tokens_key: + break + + if lm_head_key and not embed_tokens_key: + embed_tokens_key = lm_head_key.replace( + ".lm_head.weight", ".model.language_model.embed_tokens.weight" + ) + transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key] + elif embed_tokens_key and not lm_head_key: + lm_head_key = embed_tokens_key.replace( + ".model.language_model.embed_tokens.weight", ".lm_head.weight" + ) + transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key] + + return transformed_dict + + @classmethod + def _load_as_safetensor( + cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool + ) -> "PI0Policy": + """Override to apply key transformations before loading.""" + from safetensors.torch import load_file + + init_logging() + # Load the state 
dict from file safely + state_dict = load_file(model_file, device=map_location) + + # Apply key transformations + transformed_state_dict = cls._transform_state_dict_keys(state_dict) + + # Load the transformed state dict + msg = model.load_state_dict(transformed_state_dict, strict=strict) + + # Log message + log_model_loading_keys(msg.missing_keys, msg.unexpected_keys) + return model + def get_optim_params(self) -> dict: return self.parameters() @@ -515,9 +603,10 @@ class PI0FlowMatching(nn.Module): return noise def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) + return time def embed_prefix( self, images, img_masks, lang_tokens, lang_masks diff --git a/src/lerobot/policies/pi0fast/modeling_pi0fast.py b/src/lerobot/policies/pi0fast/modeling_pi0fast.py index d3903066c..88727b581 100644 --- a/src/lerobot/policies/pi0fast/modeling_pi0fast.py +++ b/src/lerobot/policies/pi0fast/modeling_pi0fast.py @@ -25,14 +25,14 @@ Disclaimer: It is not expected to perform as well as the original implementation Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/pi0fast_base \ --dataset.repo_id=danaaubakirova/koch_test ``` Example of training the pi0+FAST neural network with from scratch: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=pi0fast \ --dataset.repo_id=danaaubakirova/koch_test ``` @@ -488,6 +488,8 @@ class PI0FAST(nn.Module): param.data = param.data.to(dtype=torch_precision) self.set_requires_grad() self.image_keys = self.config.image_features.keys() + # TODO: Remove this once we bump transformers to >4.52.0 because the attribute will be removed + # 
AttributeError: 'PaliGemmaConfig' object has no attribute 'ignore_index' self.ignore_index = self.pi0_paligemma.config.ignore_index self.padding_side = self.config.padding_side diff --git a/src/lerobot/policies/pretrained.py b/src/lerobot/policies/pretrained.py index d745c901c..2f69309c1 100644 --- a/src/lerobot/policies/pretrained.py +++ b/src/lerobot/policies/pretrained.py @@ -30,6 +30,7 @@ from torch import Tensor, nn from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.train import TrainPipelineConfig +from lerobot.policies.utils import log_model_loading_keys from lerobot.utils.hub import HubMixin T = TypeVar("T", bound="PreTrainedPolicy") @@ -128,18 +129,26 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC): @classmethod def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T: - if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"): - load_model_as_safetensor(model, model_file, strict=strict) - if map_location != "cpu": - logging.warning( - "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." - " This means that the model is loaded on 'cpu' first and then copied to the device." - " This leads to a slower loading time." - " Please update safetensors to version 0.4.3 or above for improved performance." 
- ) - model.to(map_location) - else: - safetensors.torch.load_model(model, model_file, strict=strict, device=map_location) + # Create base kwargs + kwargs = {"strict": strict} + + # Add device parameter for newer versions that support it + if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"): + kwargs["device"] = map_location + + # Load the model with appropriate kwargs + missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs) + log_model_loading_keys(missing_keys, unexpected_keys) + + # For older versions, manually move to device if needed + if "device" not in kwargs and map_location != "cpu": + logging.warning( + "Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors." + " This means that the model is loaded on 'cpu' first and then copied to the device." + " This leads to a slower loading time." + " Please update safetensors to version 0.4.3 or above for improved performance." + ) + model.to(map_location) return model @abc.abstractmethod diff --git a/src/lerobot/policies/smolvla/README.md b/src/lerobot/policies/smolvla/README.md new file mode 120000 index 000000000..f8de40269 --- /dev/null +++ b/src/lerobot/policies/smolvla/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_smolvla_README.md \ No newline at end of file diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index a31e1b078..18f2fc58a 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -28,7 +28,7 @@ pip install -e ".[smolvla]" Example of finetuning the smolvla pretrained model (`smolvla_base`): ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.path=lerobot/smolvla_base \ --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ --batch_size=64 \ @@ -38,7 +38,7 @@ python -m lerobot.scripts.train \ Example of finetuning a smolVLA. 
SmolVLA is composed of a pretrained VLM, and an action expert. ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --policy.type=smolvla \ --dataset.repo_id=danaaubakirova/svla_so100_task1_v3 \ --batch_size=64 \ @@ -194,12 +194,6 @@ def create_sinusoidal_pos_embedding( return pos_emb -def sample_beta(alpha, beta, bsize, device): - gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha) - gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta) - return gamma1 / (gamma1 + gamma2) - - def make_att_2d_masks(pad_masks, att_masks): """Copied from big_vision. @@ -384,8 +378,13 @@ class SmolVLAPolicy(PreTrainedPolicy): return self.parameters() def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: + # TODO: Check if this for loop is needed. + # Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch + # In the case of offline inference, we have the action in the batch + # that why without the k != ACTION check, it will raise an error because we are trying to stack + # on an empty container. 
for k in batch: - if k in self._queues: + if k in self._queues and k != ACTION: batch[k] = torch.stack(list(self._queues[k]), dim=1) images, img_masks = self.prepare_images(batch) @@ -631,7 +630,7 @@ class VLAFlowMatching(nn.Module): └──────────────────────────────┘ """ - def __init__(self, config): + def __init__(self, config: SmolVLAConfig): super().__init__() self.config = config @@ -685,9 +684,10 @@ class VLAFlowMatching(nn.Module): return noise def sample_time(self, bsize, device): - time_beta = sample_beta(1.5, 1.0, bsize, device) + beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time = time_beta * 0.999 + 0.001 - return time.to(dtype=torch.float32, device=device) + return time def embed_prefix( self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None diff --git a/src/lerobot/policies/tdmpc/README.md b/src/lerobot/policies/tdmpc/README.md new file mode 120000 index 000000000..413ea87b8 --- /dev/null +++ b/src/lerobot/policies/tdmpc/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_tdmpc_README.md \ No newline at end of file diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index 664fe863d..7ba88e5e6 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -143,7 +143,12 @@ class TDMPCPolicy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) + batch = self.normalize_inputs(batch) + if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] diff --git 
a/src/lerobot/policies/utils.py b/src/lerobot/policies/utils.py index 5659e8727..5a3994cdf 100644 --- a/src/lerobot/policies/utils.py +++ b/src/lerobot/policies/utils.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from collections import deque import torch @@ -71,3 +72,16 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple: with torch.inference_mode(): output = module(dummy_input) return tuple(output.shape) + + +def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None: + """Log missing and unexpected keys when loading a model. + + Args: + missing_keys (list[str]): Keys that were expected but not found. + unexpected_keys (list[str]): Keys that were found but not expected. + """ + if missing_keys: + logging.warning(f"Missing key(s) when loading model: {missing_keys}") + if unexpected_keys: + logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}") diff --git a/src/lerobot/policies/vqbet/README.md b/src/lerobot/policies/vqbet/README.md new file mode 120000 index 000000000..a4ae9291a --- /dev/null +++ b/src/lerobot/policies/vqbet/README.md @@ -0,0 +1 @@ +../../../../docs/source/policy_vqbet_README.md \ No newline at end of file diff --git a/src/lerobot/policies/vqbet/modeling_vqbet.py b/src/lerobot/policies/vqbet/modeling_vqbet.py index b271298a3..feb65bb4c 100644 --- a/src/lerobot/policies/vqbet/modeling_vqbet.py +++ b/src/lerobot/policies/vqbet/modeling_vqbet.py @@ -139,11 +139,14 @@ class VQBeTPolicy(PreTrainedPolicy): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. 
""" - + # NOTE: for offline evaluation, we have action in the batch, so we need to pop it out + if ACTION in batch: + batch.pop(ACTION) batch = self.normalize_inputs(batch) batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + # NOTE: It's important that this happens after stacking the images into a single key. batch["observation.images"] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - # Note: It's important that this happens after stacking the images into a single key. + self._queues = populate_queues(self._queues, batch) if not self.vqbet.action_head.vqvae_model.discretized.item(): diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py new file mode 100644 index 000000000..8dd244c27 --- /dev/null +++ b/src/lerobot/processor/__init__.py @@ -0,0 +1,54 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .device_processor import DeviceProcessor +from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor +from .observation_processor import VanillaObservationProcessor +from .pipeline import ( + ActionProcessor, + DoneProcessor, + EnvTransition, + IdentityProcessor, + InfoProcessor, + ObservationProcessor, + ProcessorStep, + ProcessorStepRegistry, + RewardProcessor, + RobotProcessor, + TransitionKey, + TruncatedProcessor, +) +from .rename_processor import RenameProcessor + +__all__ = [ + "ActionProcessor", + "DeviceProcessor", + "DoneProcessor", + "EnvTransition", + "IdentityProcessor", + "InfoProcessor", + "NormalizerProcessor", + "UnnormalizerProcessor", + "ObservationProcessor", + "ProcessorStep", + "ProcessorStepRegistry", + "RenameProcessor", + "RewardProcessor", + "RobotProcessor", + "TransitionKey", + "TruncatedProcessor", + "VanillaObservationProcessor", +] diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py new file mode 100644 index 000000000..0f00bb470 --- /dev/null +++ b/src/lerobot/processor/device_processor.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from dataclasses import dataclass +from typing import Any + +import torch + +from lerobot.configs.types import PolicyFeature +from lerobot.processor.pipeline import EnvTransition, TransitionKey +from lerobot.utils.utils import get_safe_torch_device + + +@dataclass +class DeviceProcessor: + """Processes transitions by moving tensors to the specified device. + + This processor ensures that all tensors in the transition are moved to the + specified device (CPU or GPU) before they are returned. + """ + + device: torch.device = "cpu" + + def __post_init__(self): + self.device = get_safe_torch_device(self.device) + self.non_blocking = "cuda" in str(self.device) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + # Create a copy of the transition + new_transition = transition.copy() + + # Process observation tensors + observation = transition.get(TransitionKey.OBSERVATION) + if observation is not None: + new_observation = { + k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v + for k, v in observation.items() + } + new_transition[TransitionKey.OBSERVATION] = new_observation + + # Process action tensor + action = transition.get(TransitionKey.ACTION) + if action is not None and isinstance(action, torch.Tensor): + new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking) + + # Process reward tensor + reward = transition.get(TransitionKey.REWARD) + if reward is not None and isinstance(reward, torch.Tensor): + new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking) + + # Process done tensor + done = transition.get(TransitionKey.DONE) + if done is not None and isinstance(done, torch.Tensor): + new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking) + + # Process truncated tensor + truncated = transition.get(TransitionKey.TRUNCATED) + if truncated is not None and isinstance(truncated, torch.Tensor): + 
new_transition[TransitionKey.TRUNCATED] = truncated.to( + self.device, non_blocking=self.non_blocking + ) + + return new_transition + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {"device": self.device} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py new file mode 100644 index 000000000..14628727f --- /dev/null +++ b/src/lerobot/processor/normalize_processor.py @@ -0,0 +1,331 @@ +from __future__ import annotations + +from collections.abc import Mapping +from dataclasses import dataclass, field +from typing import Any + +import numpy as np +import torch +from torch import Tensor + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey + + +def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]: + """Convert numpy arrays and other types to torch tensors.""" + tensor_stats: dict[str, dict[str, Tensor]] = {} + for key, sub in stats.items(): + tensor_stats[key] = {} + for stat_name, value in sub.items(): + if isinstance(value, np.ndarray): + tensor_val = torch.from_numpy(value.astype(np.float32)) + elif isinstance(value, torch.Tensor): + tensor_val = value.to(dtype=torch.float32) + elif isinstance(value, (int, float, list, tuple)): + tensor_val = torch.tensor(value, dtype=torch.float32) + else: + raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}") + tensor_stats[key][stat_name] = tensor_val + return tensor_stats + + +@dataclass +@ProcessorStepRegistry.register(name="normalizer_processor") +class NormalizerProcessor: + """Normalizes observations and actions in a single processor step. 
+ + This processor handles normalization of both observation and action tensors + using either mean/std normalization or min/max scaling to a [-1, 1] range. + + For each tensor key in the stats dictionary, the processor will: + - Use mean/std normalization if those statistics are provided: (x - mean) / std + - Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1 + + The processor can be configured to normalize only specific keys by setting + the normalize_keys parameter. + """ + + # Features and normalisation map are mandatory to match the design of normalize.py + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + + # Pre-computed statistics coming from dataset.meta.stats for instance. + stats: dict[str, dict[str, Any]] | None = None + + # Explicit subset of keys to normalise. If ``None`` every key (except + # "action") found in ``stats`` will be normalised. Using a ``set`` makes + # membership checks O(1). + normalize_keys: set[str] | None = None + + eps: float = 1e-8 + + _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + + @classmethod + def from_lerobot_dataset( + cls, + dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], + *, + normalize_keys: set[str] | None = None, + eps: float = 1e-8, + ) -> NormalizerProcessor: + """Factory helper that pulls statistics from a :class:`LeRobotDataset`. + + The features and norm_map parameters are mandatory to match the design + pattern used in normalize.py. 
+ """ + + return cls( + features=features, + norm_map=norm_map, + stats=dataset.meta.stats, + normalize_keys=normalize_keys, + eps=eps, + ) + + def __post_init__(self): + # Handle deserialization from JSON config + if self.features and isinstance(list(self.features.values())[0], dict): + # Features came from JSON - need to reconstruct PolicyFeature objects + reconstructed_features = {} + for key, ft_dict in self.features.items(): + reconstructed_features[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed_features + + if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): + # norm_map came from JSON - need to reconstruct enum keys and values + reconstructed_norm_map = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed_norm_map + + # Convert statistics once so we avoid repeated numpy→Tensor conversions + # during runtime. + self.stats = self.stats or {} + self._tensor_stats = _convert_stats_to_tensors(self.stats) + + # Ensure *normalize_keys* is a set for fast look-ups and compare by + # value later when returning the configuration. + if self.normalize_keys is not None and not isinstance(self.normalize_keys, set): + self.normalize_keys = set(self.normalize_keys) + + def _normalize_obs(self, observation): + if observation is None: + return None + + # Decide which keys should be normalised for this call. + if self.normalize_keys is not None: + keys_to_norm = self.normalize_keys + else: + # Use feature map to skip action keys. 
+ keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION} + + processed = dict(observation) + for key in keys_to_norm: + if key not in processed or key not in self._tensor_stats: + continue + + orig_val = processed[key] + tensor = ( + orig_val.to(dtype=torch.float32) + if isinstance(orig_val, torch.Tensor) + else torch.as_tensor(orig_val, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} + + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + processed[key] = (tensor - mean) / (std + self.eps) + elif "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + return processed + + def _normalize_action(self, action): + if action is None or "action" not in self._tensor_stats: + return action + + tensor = ( + action.to(dtype=torch.float32) + if isinstance(action, torch.Tensor) + else torch.as_tensor(action, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return (tensor - mean) / (std + self.eps) + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1 + raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION)) + action = self._normalize_action(transition.get(TransitionKey.ACTION)) + + # Create a new transition with normalized values + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = observation + new_transition[TransitionKey.ACTION] = action + return new_transition + + def get_config(self) -> dict[str, 
Any]: + config = { + "eps": self.eps, + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } + if self.normalize_keys is not None: + # Serialise as a list for YAML / JSON friendliness + config["normalize_keys"] = sorted(self.normalize_keys) + return config + + def state_dict(self) -> dict[str, Tensor]: + flat = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor + return flat + + def load_state_dict(self, state: Mapping[str, Tensor]) -> None: + self._tensor_stats.clear() + for flat_key, tensor in state.items(): + key, stat_name = flat_key.rsplit(".", 1) + self._tensor_stats.setdefault(key, {})[stat_name] = tensor + + def reset(self): + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +@dataclass +@ProcessorStepRegistry.register(name="unnormalizer_processor") +class UnnormalizerProcessor: + """Inverse normalisation for observations and actions. + + Exactly mirrors :class:`NormalizerProcessor` but applies the inverse + transform. 
+ """ + + features: dict[str, PolicyFeature] + norm_map: dict[FeatureType, NormalizationMode] + stats: dict[str, dict[str, Any]] | None = None + + _tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False) + + @classmethod + def from_lerobot_dataset( + cls, + dataset: LeRobotDataset, + features: dict[str, PolicyFeature], + norm_map: dict[FeatureType, NormalizationMode], + ) -> UnnormalizerProcessor: + return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats) + + def __post_init__(self): + # Handle deserialization from JSON config + if self.features and isinstance(list(self.features.values())[0], dict): + # Features came from JSON - need to reconstruct PolicyFeature objects + reconstructed_features = {} + for key, ft_dict in self.features.items(): + reconstructed_features[key] = PolicyFeature( + type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"]) + ) + self.features = reconstructed_features + + if self.norm_map and isinstance(list(self.norm_map.keys())[0], str): + # norm_map came from JSON - need to reconstruct enum keys and values + reconstructed_norm_map = {} + for ft_type_str, norm_mode_str in self.norm_map.items(): + reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str) + self.norm_map = reconstructed_norm_map + + self.stats = self.stats or {} + self._tensor_stats = _convert_stats_to_tensors(self.stats) + + def _unnormalize_obs(self, observation): + if observation is None: + return None + keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION] + processed = dict(observation) + for key in keys: + if key not in processed or key not in self._tensor_stats: + continue + orig_val = processed[key] + tensor = ( + orig_val.to(dtype=torch.float32) + if isinstance(orig_val, torch.Tensor) + else torch.as_tensor(orig_val, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()} + if "mean" in stats and 
"std" in stats: + mean, std = stats["mean"], stats["std"] + processed[key] = tensor * std + mean + elif "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val + return processed + + def _unnormalize_action(self, action): + if action is None or "action" not in self._tensor_stats: + return action + tensor = ( + action.to(dtype=torch.float32) + if isinstance(action, torch.Tensor) + else torch.as_tensor(action, dtype=torch.float32) + ) + stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()} + if "mean" in stats and "std" in stats: + mean, std = stats["mean"], stats["std"] + return tensor * std + mean + if "min" in stats and "max" in stats: + min_val, max_val = stats["min"], stats["max"] + return (tensor + 1) / 2 * (max_val - min_val) + min_val + raise ValueError("Action stats must contain either ('mean','std') or ('min','max')") + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION)) + action = self._unnormalize_action(transition.get(TransitionKey.ACTION)) + + # Create a new transition with unnormalized values + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = observation + new_transition[TransitionKey.ACTION] = action + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "features": { + key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items() + }, + "norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()}, + } + + def state_dict(self) -> dict[str, Tensor]: + flat = {} + for key, sub in self._tensor_stats.items(): + for stat_name, tensor in sub.items(): + flat[f"{key}.{stat_name}"] = tensor + return flat + + def load_state_dict(self, state: Mapping[str, Tensor]) -> None: + self._tensor_stats.clear() + for flat_key, tensor in 
state.items(): + key, stat_name = flat_key.rsplit(".", 1) + self._tensor_stats.setdefault(key, {})[stat_name] = tensor + + def reset(self): + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py new file mode 100644 index 000000000..7d63db238 --- /dev/null +++ b/src/lerobot/processor/observation_processor.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass + +import einops +import numpy as np +import torch +from torch import Tensor + +from lerobot.configs.types import PolicyFeature +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry + + +@dataclass +@ProcessorStepRegistry.register(name="observation_processor") +class VanillaObservationProcessor(ObservationProcessor): + """ + Processes environment observations into the LeRobot format by handling both images and states. 
+ + Image processing: + - Converts channel-last (H, W, C) images to channel-first (C, H, W) + - Normalizes uint8 images ([0, 255]) to float32 ([0, 1]) + - Adds a batch dimension if missing + - Supports single images and image dictionaries + + State processing: + - Maps 'environment_state' to observation.environment_state + - Maps 'agent_pos' to observation.state + - Converts numpy arrays to tensors + - Adds a batch dimension if missing + """ + + def _process_single_image(self, img: np.ndarray) -> Tensor: + """Process a single image array.""" + # Convert to tensor + img_tensor = torch.from_numpy(img) + + # Add batch dimension if needed + if img_tensor.ndim == 3: + img_tensor = img_tensor.unsqueeze(0) + + # Validate image format + _, h, w, c = img_tensor.shape + if not (c < h and c < w): + raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}") + + if img_tensor.dtype != torch.uint8: + raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}") + + # Convert to channel-first format + img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous() + + # Convert to float32 and normalize to [0, 1] + img_tensor = img_tensor.type(torch.float32) / 255.0 + + return img_tensor + + def _process_observation(self, observation): + """ + Processes both image and state observations. 
+ """ + + processed_obs = observation.copy() + + if "pixels" in processed_obs: + pixels = processed_obs.pop("pixels") + + if isinstance(pixels, dict): + imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()} + else: + imgs = {OBS_IMAGE: pixels} + + for imgkey, img in imgs.items(): + processed_obs[imgkey] = self._process_single_image(img) + + if "environment_state" in processed_obs: + env_state_np = processed_obs.pop("environment_state") + env_state = torch.from_numpy(env_state_np).float() + if env_state.dim() == 1: + env_state = env_state.unsqueeze(0) + processed_obs[OBS_ENV_STATE] = env_state + + if "agent_pos" in processed_obs: + agent_pos_np = processed_obs.pop("agent_pos") + agent_pos = torch.from_numpy(agent_pos_np).float() + if agent_pos.dim() == 1: + agent_pos = agent_pos.unsqueeze(0) + processed_obs[OBS_STATE] = agent_pos + + return processed_obs + + def observation(self, observation): + return self._process_observation(observation) + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Transforms feature keys to a standardized contract. + + This method handles several renaming patterns: + - Exact matches (e.g., 'pixels' -> 'OBS_IMAGE'). + - Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE'). + - Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1'). + - Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1'). 
+ - environment_state -> OBS_ENV_STATE, + - agent_pos -> OBS_STATE, + - observation.environment_state -> OBS_ENV_STATE, + - observation.agent_pos -> OBS_STATE + """ + exact_pairs = { + "pixels": OBS_IMAGE, + "environment_state": OBS_ENV_STATE, + "agent_pos": OBS_STATE, + } + + prefix_pairs = { + "pixels.": f"{OBS_IMAGES}.", + } + + for key in list(features.keys()): + matched_prefix = False + for old_prefix, new_prefix in prefix_pairs.items(): + prefixed_old = f"observation.{old_prefix}" + if key.startswith(prefixed_old): + suffix = key[len(prefixed_old) :] + features[f"{new_prefix}{suffix}"] = features.pop(key) + matched_prefix = True + break + + if key.startswith(old_prefix): + suffix = key[len(old_prefix) :] + features[f"{new_prefix}{suffix}"] = features.pop(key) + matched_prefix = True + break + + if matched_prefix: + continue + + for old, new in exact_pairs.items(): + if key == old or key == f"observation.{old}": + if key in features: + features[new] = features.pop(key) + break + + return features diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py new file mode 100644 index 000000000..6e1b2a2cb --- /dev/null +++ b/src/lerobot/processor/pipeline.py @@ -0,0 +1,1264 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import importlib +import json +import os +from collections.abc import Callable, Iterable, Sequence +from copy import deepcopy +from dataclasses import dataclass, field +from enum import Enum +from pathlib import Path +from typing import Any, Protocol, TypedDict + +import torch +from huggingface_hub import ModelHubMixin, hf_hub_download +from huggingface_hub.errors import HfHubHTTPError +from safetensors.torch import load_file, save_file + +from lerobot.configs.types import PolicyFeature + + +class TransitionKey(str, Enum): + """Keys for accessing EnvTransition dictionary components.""" + + # TODO(Steven): Use consts + OBSERVATION = "observation" + ACTION = "action" + REWARD = "reward" + DONE = "done" + TRUNCATED = "truncated" + INFO = "info" + COMPLEMENTARY_DATA = "complementary_data" + + +EnvTransition = TypedDict( + "EnvTransition", + { + TransitionKey.OBSERVATION.value: dict[str, Any] | None, + TransitionKey.ACTION.value: Any | torch.Tensor | None, + TransitionKey.REWARD.value: float | torch.Tensor | None, + TransitionKey.DONE.value: bool | torch.Tensor | None, + TransitionKey.TRUNCATED.value: bool | torch.Tensor | None, + TransitionKey.INFO.value: dict[str, Any] | None, + TransitionKey.COMPLEMENTARY_DATA.value: dict[str, Any] | None, + }, +) + + +class ProcessorStepRegistry: + """Registry for processor steps that enables saving/loading by name instead of module path.""" + + _registry: dict[str, type] = {} + + @classmethod + def register(cls, name: str = None): + """Decorator to register a processor step class. + + Args: + name: Optional registration name. If not provided, uses class name. + + Example: + @ProcessorStepRegistry.register("adaptive_normalizer") + class AdaptiveObservationNormalizer: + ... 
+ """ + + def decorator(step_class: type) -> type: + registration_name = name if name is not None else step_class.__name__ + + if registration_name in cls._registry: + raise ValueError( + f"Processor step '{registration_name}' is already registered. " + f"Use a different name or unregister the existing one first." + ) + + cls._registry[registration_name] = step_class + # Store the registration name on the class for later reference + step_class._registry_name = registration_name + return step_class + + return decorator + + @classmethod + def get(cls, name: str) -> type: + """Get a registered processor step class by name. + + Args: + name: The registration name of the step. + + Returns: + The registered step class. + + Raises: + KeyError: If the step is not registered. + """ + if name not in cls._registry: + available = list(cls._registry.keys()) + raise KeyError( + f"Processor step '{name}' not found in registry. " + f"Available steps: {available}. " + f"Make sure the step is registered using @ProcessorStepRegistry.register()" + ) + return cls._registry[name] + + @classmethod + def unregister(cls, name: str) -> None: + """Remove a step from the registry.""" + cls._registry.pop(name, None) + + @classmethod + def list(cls) -> list[str]: + """List all registered step names.""" + return list(cls._registry.keys()) + + @classmethod + def clear(cls) -> None: + """Clear all registrations.""" + cls._registry.clear() + + +class ProcessorStep(Protocol): + """Structural typing interface for a single processor step. + + A step is any callable accepting a full `EnvTransition` dict and + returning a (possibly modified) dict of the same structure. Implementers + are encouraged—but not required—to expose the optional helper methods + listed below. When present, these hooks let `RobotProcessor` + automatically serialise the step's configuration and learnable state using + a safe-to-share JSON + SafeTensors format. 
+ + + **Required**: + - ``__call__(transition: EnvTransition) -> EnvTransition`` + - ``feature_contract(features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]`` + + Optional helper protocol: + * ``get_config() -> dict[str, Any]`` – User-defined JSON-serializable + configuration and state. YOU decide what to save here. This is where all + non-tensor state goes (e.g., name, counter, threshold, window_size). + The config dict will be passed to your class constructor when loading. + * ``state_dict() -> dict[str, torch.Tensor]`` – PyTorch tensor state ONLY. + This is exclusively for torch.Tensor objects (e.g., learned weights, + running statistics as tensors). Never put simple Python types here. + * ``load_state_dict(state)`` – Inverse of ``state_dict``. Receives a dict + containing torch tensors only. + * ``reset()`` – Clear internal buffers at episode boundaries. + + Example separation: + - get_config(): {"name": "my_step", "learning_rate": 0.01, "window_size": 10} + - state_dict(): {"weights": torch.tensor(...), "running_mean": torch.tensor(...)} + """ + + def __call__(self, transition: EnvTransition) -> EnvTransition: ... + + def get_config(self) -> dict[str, Any]: ... + + def state_dict(self) -> dict[str, torch.Tensor]: ... + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: ... + + def reset(self) -> None: ... + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: ... + + +def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noqa: D401 + """Convert a *batch* dict coming from Learobot replay/dataset code into an + ``EnvTransition`` dictionary. + + The function maps well known keys to the EnvTransition structure. Missing keys are + filled with sane defaults (``None`` or ``0.0``/``False``). + + Keys recognised (case-sensitive): + + * "observation.*" (keys starting with "observation." 
are grouped into observation dict) + * "action" + * "next.reward" + * "next.done" + * "next.truncated" + * "info" + + Additional keys are ignored so that existing dataloaders can carry extra + metadata without breaking the processor. + """ + + # Extract observation keys + observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} + observation = observation_keys if observation_keys else None + + # Extract padding and task keys for complementary data + pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} + task_key = {"task": batch["task"]} if "task" in batch else {} + complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} + + transition: EnvTransition = { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: batch.get("action"), + TransitionKey.REWARD: batch.get("next.reward", 0.0), + TransitionKey.DONE: batch.get("next.done", False), + TransitionKey.TRUNCATED: batch.get("next.truncated", False), + TransitionKey.INFO: batch.get("info", {}), + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + return transition + + +def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: # noqa: D401 + """Inverse of :pyfunc:`_default_batch_to_transition`. Returns a dict with + the canonical field names used throughout *LeRobot*. 
+ """ + + batch = { + "action": transition.get(TransitionKey.ACTION), + "next.reward": transition.get(TransitionKey.REWARD, 0.0), + "next.done": transition.get(TransitionKey.DONE, False), + "next.truncated": transition.get(TransitionKey.TRUNCATED, False), + "info": transition.get(TransitionKey.INFO, {}), + } + + # Add padding and task data from complementary_data + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data: + pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} + batch.update(pad_data) + + if "task" in complementary_data: + batch["task"] = complementary_data["task"] + + # Handle observation - flatten dict to observation.* keys if it's a dict + observation = transition.get(TransitionKey.OBSERVATION) + if isinstance(observation, dict): + batch.update(observation) + + return batch + + +@dataclass +class RobotProcessor(ModelHubMixin): + """ + Composable, debuggable post-processing processor for robot transitions. + + The class orchestrates an ordered collection of small, functional transforms—steps—executed + left-to-right on each incoming `EnvTransition`. It can process both `EnvTransition` dicts + and batch dictionaries, automatically converting between formats as needed. + + Args: + steps: Ordered list of processing steps executed on every call. Defaults to empty list. + name: Human-readable identifier that is persisted inside the JSON config. + Defaults to "RobotProcessor". + to_transition: Function to convert batch dict to EnvTransition dict. + Defaults to _default_batch_to_transition. + to_output: Function to convert EnvTransition dict to the desired output format. + Usually it is a batch dict or EnvTransition dict. + Defaults to _default_transition_to_batch. + before_step_hooks: List of hooks called before each step. Each hook receives the step + index and transition, and can optionally return a modified transition. + after_step_hooks: List of hooks called after each step. 
Each hook receives the step + index and transition, and can optionally return a modified transition. + + Hook Semantics: + - Hooks are executed sequentially in the order they were registered. There is no way to + reorder hooks after registration without creating a new pipeline. + - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called + with the step index and current transition for logging, debugging, or monitoring purposes. + - All hooks for a given type (before/after) are executed for every step, or none at all if + an error occurs. There is no partial execution of hooks. + - Hooks should generally be stateless to maintain predictable behavior. If you need stateful + processing, consider implementing a proper ProcessorStep instead. + - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. + - Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format + passed to __call__. This ensures consistent hook behavior whether processing batch dicts + or EnvTransition objects. + """ + + steps: Sequence[ProcessorStep] = field(default_factory=list) + name: str = "RobotProcessor" + + to_transition: Callable[[dict[str, Any]], EnvTransition] = field( + default_factory=lambda: _default_batch_to_transition, repr=False + ) + to_output: Callable[[EnvTransition], dict[str, Any] | EnvTransition] = field( + default_factory=lambda: _default_transition_to_batch, repr=False + ) + + # Processor-level hooks for observation/monitoring + # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes + before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + + def __call__(self, data: EnvTransition | dict[str, Any]): + """Process data through all steps. 
+ + The method accepts either the classic EnvTransition dict or a batch dictionary + (like the ones returned by ReplayBuffer or LeRobotDataset). If a dict is supplied + it is first converted to the internal dict format using to_transition; after all + steps are executed the dict is transformed back into a batch dict with to_batch and the + result is returned – thereby preserving the caller's original data type. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Returns: + The processed data in the same format as the input (EnvTransition or batch dict). + + Raises: + ValueError: If the transition is not a valid EnvTransition format. + """ + # Check if we need to convert back to batch format at the end + _, called_with_batch = self._prepare_transition(data) + + # Use step_through to get the iterator + step_iterator = self.step_through(data) + + # Get initial state (before any steps) + current_transition = next(step_iterator) + + # Process each step with hooks + for idx, next_transition in enumerate(step_iterator): + # Apply before hooks with current state (before step execution) + for hook in self.before_step_hooks: + hook(idx, current_transition) + + # Move to next state (after step execution) + current_transition = next_transition + + # Apply after hooks with updated state + for hook in self.after_step_hooks: + hook(idx, current_transition) + + # Convert back to original format if needed + return self.to_output(current_transition) if called_with_batch else current_transition + + def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: + """Prepare and validate transition data for processing. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Returns: + A tuple of (prepared_transition, called_with_batch_flag) + + Raises: + ValueError: If the transition is not a valid EnvTransition format. 
+ """ + # Check if data is already an EnvTransition or needs conversion + if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): + # It's a batch dict, convert it + called_with_batch = True + transition = self.to_transition(data) + else: + # It's already an EnvTransition + called_with_batch = False + transition = data + + # Basic validation + if not isinstance(transition, dict): + raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") + + return transition, called_with_batch + + def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]: + """Yield the intermediate results after each processor step. + + This is a low-level method that does NOT apply hooks. It simply executes each step + and yields the intermediate results. This allows users to debug the pipeline or + apply custom logic between steps if needed. + + Note: This method always yields EnvTransition objects regardless of input format. + If you need the results in the original input format, you'll need to convert them + using `to_output()`. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Yields: + The intermediate EnvTransition results after each step. 
+ """ + transition, _ = self._prepare_transition(data) + + # Yield initial state + yield transition + + # Process each step WITHOUT hooks (low-level method) + for processor_step in self.steps: + transition = processor_step(transition) + yield transition + + def _save_pretrained(self, save_directory: Path, **kwargs): + """Internal save method for ModelHubMixin compatibility.""" + # Extract config_filename from kwargs if provided + config_filename = kwargs.pop("config_filename", None) + self.save_pretrained(save_directory, config_filename=config_filename) + + def save_pretrained(self, save_directory: str | Path, config_filename: str | None = None, **kwargs): + """Serialize the processor definition and parameters to *save_directory*. + + Args: + save_directory: Directory where the processor will be saved. + config_filename: Optional custom config filename. If not provided, defaults to + "{self.name}.json" where self.name is sanitized for filesystem compatibility. + """ + os.makedirs(str(save_directory), exist_ok=True) + + # Sanitize processor name for use in filenames + import re + + # The huggingface hub does not allow special characters in the repo name, so we sanitize the name + sanitized_name = re.sub(r"[^a-zA-Z0-9_]", "_", self.name.lower()) + + # Use sanitized name for config if not provided + if config_filename is None: + config_filename = f"{sanitized_name}.json" + + config: dict[str, Any] = { + "name": self.name, + "steps": [], + } + + for step_index, processor_step in enumerate(self.steps): + # Check if step was registered + registry_name = getattr(processor_step.__class__, "_registry_name", None) + + step_entry: dict[str, Any] = {} + if registry_name: + # Use registry name for registered steps + step_entry["registry_name"] = registry_name + else: + # Fall back to full module path for unregistered steps + step_entry["class"] = ( + f"{processor_step.__class__.__module__}.{processor_step.__class__.__name__}" + ) + + if hasattr(processor_step, "get_config"): + 
step_entry["config"] = processor_step.get_config() + + if hasattr(processor_step, "state_dict"): + state = processor_step.state_dict() + if state: + # Clone tensors to avoid shared memory issues + # This ensures each tensor has its own memory allocation + # The reason is to avoid the following error: + # RuntimeError: Some tensors share memory, this will lead to duplicate memory on disk + # and potential differences when loading them again + # ------------------------------------------------------------------------------ + # Since the state_dict of processor will be light, we can just clone the tensors + # and save them to the disk. + cloned_state = {} + for key, tensor in state.items(): + cloned_state[key] = tensor.clone() + + # Include pipeline name and step index to ensure unique filenames + # This prevents conflicts when multiple processors are saved in the same directory + if registry_name: + state_filename = f"{sanitized_name}_step_{step_index}_{registry_name}.safetensors" + else: + state_filename = f"{sanitized_name}_step_{step_index}.safetensors" + + save_file(cloned_state, os.path.join(str(save_directory), state_filename)) + step_entry["state_file"] = state_filename + + config["steps"].append(step_entry) + + with open(os.path.join(str(save_directory), config_filename), "w") as file_pointer: + json.dump(config, file_pointer, indent=2) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: str | Path, + *, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict[str, str] | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + config_filename: str | None = None, + overrides: dict[str, Any] | None = None, + **kwargs, + ) -> RobotProcessor: + """Load a serialized processor from source (local path or Hugging Face Hub identifier). 
+ + Args: + pretrained_model_name_or_path: Local path to a saved processor directory or Hugging Face Hub identifier + (e.g., "username/processor-name"). + config_filename: Optional specific config filename to load. If not provided, will: + - For local paths: look for any .json file in the directory (error if multiple found) + - For HF Hub: try common names ("processor.json", "preprocessor.json", "postprocessor.json") + overrides: Optional dictionary mapping step names to configuration overrides. + Keys must match exact step class names (for unregistered steps) or registry names + (for registered steps). Values are dictionaries containing parameter overrides + that will be merged with the saved configuration. This is useful for providing + non-serializable objects like environment instances. + + Returns: + A RobotProcessor instance loaded from the saved configuration. + + Raises: + ImportError: If a processor step class cannot be loaded or imported. + ValueError: If a step cannot be instantiated with the provided configuration. + KeyError: If an override key doesn't match any step in the saved configuration. 
+ + Examples: + Basic loading: + ```python + processor = RobotProcessor.from_pretrained("path/to/processor") + ``` + + Loading specific config file: + ```python + processor = RobotProcessor.from_pretrained( + "username/multi-processor-repo", config_filename="preprocessor.json" + ) + ``` + + Loading with overrides for non-serializable objects: + ```python + import gym + + env = gym.make("CartPole-v1") + processor = RobotProcessor.from_pretrained( + "username/cartpole-processor", overrides={"ActionRepeatStep": {"env": env}} + ) + ``` + + Multiple overrides: + ```python + processor = RobotProcessor.from_pretrained( + "path/to/processor", + overrides={ + "CustomStep": {"param1": "new_value"}, + "device_processor": {"device": "cuda:1"}, # For registered steps + }, + ) + ``` + """ + # Use the local variable name 'source' for clarity + source = str(pretrained_model_name_or_path) + + if Path(source).is_dir(): + # Local path - use it directly + base_path = Path(source) + + if config_filename is None: + # Look for any .json file in the directory + json_files = list(base_path.glob("*.json")) + if len(json_files) == 0: + raise FileNotFoundError(f"No .json configuration files found in {source}") + elif len(json_files) > 1: + raise ValueError( + f"Multiple .json files found in {source}: {[f.name for f in json_files]}. " + f"Please specify which one to load using the config_filename parameter." 
+ ) + config_filename = json_files[0].name + + with open(base_path / config_filename) as file_pointer: + loaded_config: dict[str, Any] = json.load(file_pointer) + else: + # Hugging Face Hub - download all required files + if config_filename is None: + # Try common config names + common_names = [ + "processor.json", + "preprocessor.json", + "postprocessor.json", + "robotprocessor.json", + ] + config_path = None + for name in common_names: + try: + config_path = hf_hub_download( + source, + name, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + config_filename = name + break + except (FileNotFoundError, OSError, HfHubHTTPError): + # FileNotFoundError: local file issues + # OSError: network/system errors + # HfHubHTTPError: file not found on Hub (404) or other HTTP errors + continue + + if config_path is None: + raise FileNotFoundError( + f"No processor configuration file found in {source}. " + f"Tried: {common_names}. Please specify the config_filename parameter." 
+ ) + else: + # Download specific config file + config_path = hf_hub_download( + source, + config_filename, + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + with open(config_path) as file_pointer: + loaded_config = json.load(file_pointer) + + # Store downloaded files in the same directory as the config + base_path = Path(config_path).parent + + # Handle None overrides + if overrides is None: + overrides = {} + + # Validate that all override keys will be matched + override_keys = set(overrides.keys()) + + steps: list[ProcessorStep] = [] + for step_entry in loaded_config["steps"]: + # Check if step uses registry name or module path + if "registry_name" in step_entry: + # Load from registry + try: + step_class = ProcessorStepRegistry.get(step_entry["registry_name"]) + step_key = step_entry["registry_name"] + except KeyError as e: + raise ImportError(f"Failed to load processor step from registry. {str(e)}") from e + else: + # Fall back to module path loading for backward compatibility + full_class_path = step_entry["class"] + module_path, class_name = full_class_path.rsplit(".", 1) + + # Import the module containing the step class + try: + module = importlib.import_module(module_path) + step_class = getattr(module, class_name) + step_key = class_name + except (ImportError, AttributeError) as e: + raise ImportError( + f"Failed to load processor step '{full_class_path}'. " + f"Make sure the module '{module_path}' is installed and contains class '{class_name}'. " + f"Consider registering the step using @ProcessorStepRegistry.register() for better portability. 
" + f"Error: {str(e)}" + ) from e + + # Instantiate the step with its config + try: + saved_cfg = step_entry.get("config", {}) + step_overrides = overrides.get(step_key, {}) + merged_cfg = {**saved_cfg, **step_overrides} + step_instance: ProcessorStep = step_class(**merged_cfg) + + # Track which override keys were used + if step_key in override_keys: + override_keys.discard(step_key) + + except Exception as e: + step_name = step_entry.get("registry_name", step_entry.get("class", "Unknown")) + raise ValueError( + f"Failed to instantiate processor step '{step_name}' with config: {step_entry.get('config', {})}. " + f"Error: {str(e)}" + ) from e + + # Load state if available + if "state_file" in step_entry and hasattr(step_instance, "load_state_dict"): + if Path(source).is_dir(): + # Local path - read directly + state_path = str(base_path / step_entry["state_file"]) + else: + # Hugging Face Hub - download the state file + state_path = hf_hub_download( + source, + step_entry["state_file"], + repo_type="model", + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + ) + + step_instance.load_state_dict(load_file(state_path)) + + steps.append(step_instance) + + # Check for unused override keys + if override_keys: + available_keys = [] + for step_entry in loaded_config["steps"]: + if "registry_name" in step_entry: + available_keys.append(step_entry["registry_name"]) + else: + full_class_path = step_entry["class"] + class_name = full_class_path.rsplit(".", 1)[1] + available_keys.append(class_name) + + raise KeyError( + f"Override keys {list(override_keys)} do not match any step in the saved configuration. " + f"Available step keys: {available_keys}. " + f"Make sure override keys match exact step class names or registry names." 
+ ) + + return cls(steps, loaded_config.get("name", "RobotProcessor")) + + def __len__(self) -> int: + """Return the number of steps in the processor.""" + return len(self.steps) + + def __getitem__(self, idx: int | slice) -> ProcessorStep | RobotProcessor: + """Indexing helper exposing underlying steps. + * ``int`` – returns the idx-th ProcessorStep. + * ``slice`` – returns a new RobotProcessor with the sliced steps. + """ + if isinstance(idx, slice): + return RobotProcessor(self.steps[idx], self.name) + return self.steps[idx] + + def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Attach fn to be executed before every processor step.""" + self.before_step_hooks.append(fn) + + def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Remove a previously registered before_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.before_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference." + ) from None + + def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Attach fn to be executed after every processor step.""" + self.after_step_hooks.append(fn) + + def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): + """Remove a previously registered after_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.after_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference." 
+ ) from None + + def reset(self): + """Clear state in every step that implements ``reset()`` and fire registered hooks.""" + for step in self.steps: + if hasattr(step, "reset"): + step.reset() # type: ignore[attr-defined] + + def __repr__(self) -> str: + """Return a readable string representation of the processor.""" + step_names = [step.__class__.__name__ for step in self.steps] + + if not step_names: + steps_repr = "steps=0: []" + elif len(step_names) <= 3: + steps_repr = f"steps={len(step_names)}: [{', '.join(step_names)}]" + else: + # Show first 2 and last 1 with ellipsis for long lists + displayed = f"{step_names[0]}, {step_names[1]}, ..., {step_names[-1]}" + steps_repr = f"steps={len(step_names)}: [{displayed}]" + + parts = [f"name='{self.name}'", steps_repr] + + return f"RobotProcessor({', '.join(parts)})" + + def __post_init__(self): + for i, step in enumerate(self.steps): + if not callable(step): + raise TypeError( + f"Step {i} ({type(step).__name__}) must define __call__(transition) -> EnvTransition" + ) + + fc = getattr(step, "feature_contract", None) + if not callable(fc): + raise TypeError( + f"Step {i} ({type(step).__name__}) must define feature_contract(features) -> dict[str, Any]" + ) + + def feature_contract(self, initial_features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """ + Apply ALL steps in order. Each step must implement + feature_contract(features) and return a dict (full or incremental schema). + """ + features: dict[str, PolicyFeature] = deepcopy(initial_features) + + for _, step in enumerate(self.steps): + out = step.feature_contract(features) + if not isinstance(out, dict): + raise TypeError(f"{step.__class__.__name__}.feature_contract must return dict[str, Any]") + features = out + return features + + +class ObservationProcessor: + """Base class for processors that modify only the observation component of a transition. + + Subclasses should override the `observation` method to implement custom observation processing. 
+ This class handles the boilerplate of extracting and reinserting the processed observation + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + + Example: + ```python + class MyObservationScaler(ObservationProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def observation(self, observation): + return observation * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific observation processing logic. + """ + + def observation(self, observation): + """Process the observation component. + + Args: + observation: The observation to process + + Returns: + The processed observation + """ + return observation + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation = transition.get(TransitionKey.OBSERVATION) + if observation is None: + return transition + + processed_observation = self.observation(observation) + # Create a new transition dict with the processed observation + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_observation + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class ActionProcessor: + """Base class for processors that modify only the action component of a transition. + + Subclasses should override the `action` method to implement custom action processing. + This class handles the boilerplate of extracting and reinserting the processed action + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class ActionClipping(ActionProcessor): + def __init__(self, min_val, max_val): + self.min_val = min_val + self.max_val = max_val + + def action(self, action): + return np.clip(action, self.min_val, self.max_val) + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific action processing logic. + """ + + def action(self, action): + """Process the action component. + + Args: + action: The action to process + + Returns: + The processed action + """ + return action + + def __call__(self, transition: EnvTransition) -> EnvTransition: + action = transition.get(TransitionKey.ACTION) + if action is None: + return transition + + processed_action = self.action(action) + # Create a new transition dict with the processed action + new_transition = transition.copy() + new_transition[TransitionKey.ACTION] = processed_action + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class RewardProcessor: + """Base class for processors that modify only the reward component of a transition. + + Subclasses should override the `reward` method to implement custom reward processing. + This class handles the boilerplate of extracting and reinserting the processed reward + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class RewardScaler(RewardProcessor): + def __init__(self, scale_factor): + self.scale_factor = scale_factor + + def reward(self, reward): + return reward * self.scale_factor + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific reward processing logic. + """ + + def reward(self, reward): + """Process the reward component. + + Args: + reward: The reward to process + + Returns: + The processed reward + """ + return reward + + def __call__(self, transition: EnvTransition) -> EnvTransition: + reward = transition.get(TransitionKey.REWARD) + if reward is None: + return transition + + processed_reward = self.reward(reward) + # Create a new transition dict with the processed reward + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = processed_reward + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class DoneProcessor: + """Base class for processors that modify only the done flag of a transition. + + Subclasses should override the `done` method to implement custom done flag processing. + This class handles the boilerplate of extracting and reinserting the processed done flag + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class TimeoutDone(DoneProcessor): + def __init__(self, max_steps): + self.steps = 0 + self.max_steps = max_steps + + def done(self, done): + self.steps += 1 + return done or self.steps >= self.max_steps + + def reset(self): + self.steps = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific done flag processing logic. + """ + + def done(self, done): + """Process the done flag. + + Args: + done: The done flag to process + + Returns: + The processed done flag + """ + return done + + def __call__(self, transition: EnvTransition) -> EnvTransition: + done = transition.get(TransitionKey.DONE) + if done is None: + return transition + + processed_done = self.done(done) + # Create a new transition dict with the processed done flag + new_transition = transition.copy() + new_transition[TransitionKey.DONE] = processed_done + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class TruncatedProcessor: + """Base class for processors that modify only the truncated flag of a transition. + + Subclasses should override the `truncated` method to implement custom truncated flag processing. + This class handles the boilerplate of extracting and reinserting the processed truncated flag + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class EarlyTruncation(TruncatedProcessor): + def __init__(self, threshold): + self.threshold = threshold + + def truncated(self, truncated): + # Additional truncation condition + return truncated or some_condition > self.threshold + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific truncated flag processing logic. + """ + + def truncated(self, truncated): + """Process the truncated flag. + + Args: + truncated: The truncated flag to process + + Returns: + The processed truncated flag + """ + return truncated + + def __call__(self, transition: EnvTransition) -> EnvTransition: + truncated = transition.get(TransitionKey.TRUNCATED) + if truncated is None: + return transition + + processed_truncated = self.truncated(truncated) + # Create a new transition dict with the processed truncated flag + new_transition = transition.copy() + new_transition[TransitionKey.TRUNCATED] = processed_truncated + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class InfoProcessor: + """Base class for processors that modify only the info dictionary of a transition. + + Subclasses should override the `info` method to implement custom info processing. + This class handles the boilerplate of extracting and reinserting the processed info + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. 
+ + Example: + ```python + class InfoAugmenter(InfoProcessor): + def __init__(self): + self.step_count = 0 + + def info(self, info): + info = info.copy() # Create a copy to avoid modifying the original + info["steps"] = self.step_count + self.step_count += 1 + return info + + def reset(self): + self.step_count = 0 + ``` + + By inheriting from this class, you avoid writing repetitive code to handle transition dict + manipulation, focusing only on the specific info dictionary processing logic. + """ + + def info(self, info): + """Process the info dictionary. + + Args: + info: The info dictionary to process + + Returns: + The processed info dictionary + """ + return info + + def __call__(self, transition: EnvTransition) -> EnvTransition: + info = transition.get(TransitionKey.INFO) + if info is None: + return transition + + processed_info = self.info(info) + # Create a new transition dict with the processed info + new_transition = transition.copy() + new_transition[TransitionKey.INFO] = processed_info + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class ComplementaryDataProcessor: + """Base class for processors that modify only the complementary data of a transition. + + Subclasses should override the `complementary_data` method to implement custom complementary data processing. + This class handles the boilerplate of extracting and reinserting the processed complementary data + into the transition dict, eliminating the need to implement the `__call__` method in subclasses. + """ + + def complementary_data(self, complementary_data): + """Process the complementary data. 
+ + Args: + complementary_data: The complementary data to process + + Returns: + The processed complementary data + """ + return complementary_data + + def __call__(self, transition: EnvTransition) -> EnvTransition: + complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA) + if complementary_data is None: + return transition + + processed_complementary_data = self.complementary_data(complementary_data) + # Create a new transition dict with the processed complementary data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = processed_complementary_data + return new_transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +class IdentityProcessor: + """Identity processor that does nothing.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self) -> dict[str, Any]: + return {} + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py new file mode 100644 index 000000000..4fe4105a5 --- /dev/null +++ b/src/lerobot/processor/rename_processor.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass, field +from typing import Any + +from lerobot.configs.types import PolicyFeature +from lerobot.processor.pipeline import ( + ObservationProcessor, + ProcessorStepRegistry, +) + + +@dataclass +@ProcessorStepRegistry.register(name="rename_processor") +class RenameProcessor(ObservationProcessor): + """Rename processor that renames keys in the observation.""" + + rename_map: dict[str, str] = field(default_factory=dict) + + def observation(self, observation): + processed_obs = {} + for key, value in observation.items(): + if key in self.rename_map: + processed_obs[self.rename_map[key]] = value + else: + processed_obs[key] = value + + return processed_obs + + def get_config(self) -> dict[str, Any]: + return {"rename_map": self.rename_map} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + """Transforms: + - Each key in the observation that appears in `rename_map` is renamed to its value. + - Keys not in `rename_map` remain unchanged. + """ + return {self.rename_map.get(k, k): v for k, v in features.items()} diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 904a1c7d0..6900ceb35 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -18,7 +18,7 @@ Records a dataset. 
Actions for the robot can be either generated by teleoperatio Example: ```shell -python -m lerobot.record \ +lerobot-record \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.cameras="{laptop: {type: opencv, camera_index: 0, width: 640, height: 480}}" \ @@ -36,7 +36,7 @@ python -m lerobot.record \ Example recording with bimanual so100: ```shell -python -m lerobot.record \ +lerobot-record \ --robot.type=bi_so100_follower \ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ @@ -395,5 +395,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset: return dataset -if __name__ == "__main__": +def main(): record() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/replay.py b/src/lerobot/replay.py index 2c1eafea0..603aa93ea 100644 --- a/src/lerobot/replay.py +++ b/src/lerobot/replay.py @@ -18,7 +18,7 @@ Replays the actions of an episode from a dataset on a robot. Examples: ```shell -python -m lerobot.replay \ +lerobot-replay \ --robot.type=so100_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.id=black \ @@ -28,7 +28,7 @@ python -m lerobot.replay \ Example replay with bimanual so100: ```shell -python -m lerobot.replay \ +lerobot-replay \ --robot.type=bi_so100_follower \ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ @@ -113,5 +113,9 @@ def replay(cfg: ReplayConfig): robot.disconnect() -if __name__ == "__main__": +def main(): replay() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index 4e90c99c7..bbc9f7223 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -63,7 +63,7 @@ python lerobot/scripts/control_robot.py \ --control.type=teleoperate ``` -By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in 
[`AlohaRobotConfig`](lerobot/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: +By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`ViperXConfig`](./config_viperx.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: ```bash python lerobot/scripts/control_robot.py \ @@ -141,10 +141,10 @@ python lerobot/scripts/control_robot.py \ ## Train a policy -To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +To train a policy to control your robot, use the [`lerobot-train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/aloha_test \ --policy.type=act \ --output_dir=outputs/train/act_aloha_test \ diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index 7c5aec48a..13d30c686 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -21,7 +21,7 @@ You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/di for 10 episodes. ``` -python -m lerobot.scripts.eval \ +lerobot-eval \ --policy.path=lerobot/diffusion_pusht \ --env.type=pusht \ --eval.batch_size=10 \ @@ -32,7 +32,7 @@ python -m lerobot.scripts.eval \ OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. 
``` -python -m lerobot.scripts.eval \ +lerobot-eval \ --policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \ --env.type=pusht \ --eval.batch_size=10 \ @@ -501,6 +501,10 @@ def eval_main(cfg: EvalPipelineConfig): logging.info("End of eval") -if __name__ == "__main__": +def main(): init_logging() eval_main() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index f09d231a8..235352cd8 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -286,6 +286,10 @@ def train(cfg: TrainPipelineConfig): policy.push_model_to_hub(cfg) -if __name__ == "__main__": +def main(): init_logging() train() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/setup_motors.py b/src/lerobot/setup_motors.py index c54582a1d..c1d256c21 100644 --- a/src/lerobot/setup_motors.py +++ b/src/lerobot/setup_motors.py @@ -18,7 +18,7 @@ Helper to set motor ids and baudrate. Example: ```shell -python -m lerobot.setup_motors \ +lerobot-setup-motors \ --teleop.type=so100_leader \ --teleop.port=/dev/tty.usbmodem575E0031751 ``` @@ -80,5 +80,9 @@ def setup_motors(cfg: SetupConfig): device.setup_motors() -if __name__ == "__main__": +def main(): setup_motors() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/teleoperate.py b/src/lerobot/teleoperate.py index 9836f1393..e7be6967b 100644 --- a/src/lerobot/teleoperate.py +++ b/src/lerobot/teleoperate.py @@ -18,7 +18,7 @@ Simple script to control a robot from teleoperation. 
Example: ```shell -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=so101_follower \ --robot.port=/dev/tty.usbmodem58760431541 \ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ @@ -32,7 +32,7 @@ python -m lerobot.teleoperate \ Example teleoperation with bimanual so100: ```shell -python -m lerobot.teleoperate \ +lerobot-teleoperate \ --robot.type=bi_so100_follower \ --robot.left_arm_port=/dev/tty.usbmodem5A460851411 \ --robot.right_arm_port=/dev/tty.usbmodem5A460812391 \ @@ -153,5 +153,9 @@ def teleoperate(cfg: TeleoperateConfig): robot.disconnect() -if __name__ == "__main__": +def main(): teleoperate() + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md index 7b7aaa84a..9293d6ba7 100644 --- a/src/lerobot/templates/lerobot_modelcard_template.md +++ b/src/lerobot/templates/lerobot_modelcard_template.md @@ -44,7 +44,7 @@ Below is the short version on how to train and run inference/eval: ### Train from scratch ```bash -python -m lerobot.scripts.train \ +lerobot-train \ --dataset.repo_id=${HF_USER}/ \ --policy.type=act \ --output_dir=outputs/train/ \ @@ -59,7 +59,7 @@ _Writes checkpoints to `outputs/train//checkpoints/`._ ### Evaluate the policy/run inference ```bash -python -m lerobot.record \ +lerobot-record \ --robot.type=so100_follower \ --dataset.repo_id=/eval_ \ --policy.path=/ \ diff --git a/src/lerobot/utils/robot_utils.py b/src/lerobot/utils/robot_utils.py index e6c0cfe6d..8069b3662 100644 --- a/src/lerobot/utils/robot_utils.py +++ b/src/lerobot/utils/robot_utils.py @@ -17,10 +17,9 @@ import time def busy_wait(seconds): - if platform.system() == "Darwin": - # On Mac, `time.sleep` is not accurate and we need to use this while loop trick, + if platform.system() == "Darwin" or platform.system() == "Windows": + # On Mac and Windows, `time.sleep` is not accurate and we need to use this while 
loop trick, # but it consumes CPU cycles. - # TODO(rcadene): find an alternative: from python 11, time.sleep is precise end_time = time.perf_counter() + seconds while time.perf_counter() < end_time: pass diff --git a/tests/async_inference/test_robot_client.py b/tests/async_inference/test_robot_client.py index d1273ae63..51db2c3a7 100644 --- a/tests/async_inference/test_robot_client.py +++ b/tests/async_inference/test_robot_client.py @@ -13,7 +13,7 @@ # limitations under the License. """Unit-tests for the `RobotClient` action-queue logic (pure Python, no gRPC). -We monkey-patch `lerobot.common.robot_devices.robots.utils.make_robot` so that +We monkey-patch `lerobot.robots.utils.make_robot_from_config` so that no real hardware is accessed. Only the queue-update mechanism is verified. """ diff --git a/tests/conftest.py b/tests/conftest.py index 69dd3049b..7940cc5ba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,6 +19,7 @@ import traceback import pytest from serial import SerialException +from lerobot.configs.types import FeatureType, PolicyFeature from tests.utils import DEVICE # Import fixture modules as plugins @@ -69,3 +70,19 @@ def patch_builtins_input(monkeypatch): print(text) monkeypatch.setattr("builtins.input", print_text) + + +@pytest.fixture +def policy_feature_factory(): + """PolicyFeature factory""" + + def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature: + return PolicyFeature(type=ft, shape=shape) + + return _pf + + +def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None: + assert isinstance(features, dict) + assert all(isinstance(k, str) for k in features.keys()) + assert all(isinstance(v, PolicyFeature) for v in features.values()) diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index ed37fedd6..da7573d7c 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -27,11 +27,13 @@ from lerobot import available_policies from lerobot.configs.default 
import DatasetConfig from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.constants import ACTION, OBS_STATE from lerobot.datasets.factory import make_dataset from lerobot.datasets.utils import cycle, dataset_to_policy_features from lerobot.envs.factory import make_env, make_env_config from lerobot.envs.utils import preprocess_observation from lerobot.optim.factory import make_optimizer_and_scheduler +from lerobot.policies.act.configuration_act import ACTConfig from lerobot.policies.act.modeling_act import ACTTemporalEnsembler from lerobot.policies.factory import ( get_policy_class, @@ -363,6 +365,54 @@ def test_normalize(insert_temporal_dim): unnormalize(output_batch) +@pytest.mark.parametrize("multikey", [True, False]) +def test_multikey_construction(multikey: bool): + """ + Asserts that multiple keys with type State/Action are correctly processed by the policy constructor, + preventing erroneous creation of the policy object. 
+ """ + input_features = { + "observation.state": PolicyFeature( + type=FeatureType.STATE, + shape=(10,), + ), + } + output_features = { + "action": PolicyFeature( + type=FeatureType.ACTION, + shape=(5,), + ), + } + + if multikey: + """Simulates the complete state/action is constructed from more granular multiple + keys, of the same type as the overall state/action""" + input_features = {} + input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,)) + input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,)) + + output_features = {} + output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,)) + output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,)) + output_features["action"] = PolicyFeature( + type=FeatureType.ACTION, + shape=(5,), + ) + + config = ACTConfig(input_features=input_features, output_features=output_features) + + state_condition = config.robot_state_feature == input_features[OBS_STATE] + action_condition = config.action_feature == output_features[ACTION] + + assert state_condition, ( + f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}" + ) + assert action_condition, ( + f"Discrepancy detected. 
Action feature is {config.action_feature} but policy expects {output_features[ACTION]}" + ) + + @pytest.mark.parametrize( "ds_repo_id, policy_name, policy_kwargs, file_name_extra", [ diff --git a/tests/processor/test_batch_conversion.py b/tests/processor/test_batch_conversion.py new file mode 100644 index 000000000..63894025d --- /dev/null +++ b/tests/processor/test_batch_conversion.py @@ -0,0 +1,282 @@ +import torch + +from lerobot.processor.pipeline import ( + RobotProcessor, + TransitionKey, + _default_batch_to_transition, + _default_transition_to_batch, +) + + +def _dummy_batch(): + """Create a dummy batch using the new format with observation.* and next.* keys.""" + return { + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.image.right": torch.randn(1, 3, 128, 128), + "observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]), + "action": torch.tensor([[0.5]]), + "next.reward": 1.0, + "next.done": False, + "next.truncated": False, + "info": {"key": "value"}, + } + + +def test_observation_grouping_roundtrip(): + """Test that observation.* keys are properly grouped and ungrouped.""" + proc = RobotProcessor([]) + batch_in = _dummy_batch() + batch_out = proc(batch_in) + + # Check that all observation.* keys are preserved + original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")} + reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")} + + assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys()) + + # Check tensor values + assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"]) + assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"]) + assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"]) + + # Check other fields + assert torch.allclose(batch_out["action"], batch_in["action"]) + assert batch_out["next.reward"] == batch_in["next.reward"] + assert 
batch_out["next.done"] == batch_in["next.done"] + assert batch_out["next.truncated"] == batch_in["next.truncated"] + assert batch_out["info"] == batch_in["info"] + + +def test_batch_to_transition_observation_grouping(): + """Test that _default_batch_to_transition correctly groups observation.* keys.""" + batch = { + "observation.image.top": torch.randn(1, 3, 128, 128), + "observation.image.left": torch.randn(1, 3, 128, 128), + "observation.state": [1, 2, 3, 4], + "action": "action_data", + "next.reward": 1.5, + "next.done": True, + "next.truncated": False, + "info": {"episode": 42}, + } + + transition = _default_batch_to_transition(batch) + + # Check observation is a dict with all observation.* keys + assert isinstance(transition[TransitionKey.OBSERVATION], dict) + assert "observation.image.top" in transition[TransitionKey.OBSERVATION] + assert "observation.image.left" in transition[TransitionKey.OBSERVATION] + assert "observation.state" in transition[TransitionKey.OBSERVATION] + + # Check values are preserved + assert torch.allclose( + transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"] + ) + assert torch.allclose( + transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"] + ) + assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4] + + # Check other fields + assert transition[TransitionKey.ACTION] == "action_data" + assert transition[TransitionKey.REWARD] == 1.5 + assert transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {"episode": 42} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} + + +def test_transition_to_batch_observation_flattening(): + """Test that _default_transition_to_batch correctly flattens observation dict.""" + observation_dict = { + "observation.image.top": torch.randn(1, 3, 128, 128), + "observation.image.left": torch.randn(1, 3, 128, 128), + 
"observation.state": [1, 2, 3, 4], + } + + transition = { + TransitionKey.OBSERVATION: observation_dict, + TransitionKey.ACTION: "action_data", + TransitionKey.REWARD: 1.5, + TransitionKey.DONE: True, + TransitionKey.TRUNCATED: False, + TransitionKey.INFO: {"episode": 42}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + + batch = _default_transition_to_batch(transition) + + # Check that observation.* keys are flattened back to batch + assert "observation.image.top" in batch + assert "observation.image.left" in batch + assert "observation.state" in batch + + # Check values are preserved + assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"]) + assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"]) + assert batch["observation.state"] == [1, 2, 3, 4] + + # Check other fields are mapped to next.* format + assert batch["action"] == "action_data" + assert batch["next.reward"] == 1.5 + assert batch["next.done"] + assert not batch["next.truncated"] + assert batch["info"] == {"episode": 42} + + +def test_no_observation_keys(): + """Test behavior when there are no observation.* keys.""" + batch = { + "action": "action_data", + "next.reward": 2.0, + "next.done": False, + "next.truncated": True, + "info": {"test": "no_obs"}, + } + + transition = _default_batch_to_transition(batch) + + # Observation should be None when no observation.* keys + assert transition[TransitionKey.OBSERVATION] is None + + # Check other fields + assert transition[TransitionKey.ACTION] == "action_data" + assert transition[TransitionKey.REWARD] == 2.0 + assert not transition[TransitionKey.DONE] + assert transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {"test": "no_obs"} + + # Round trip should work + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["action"] == "action_data" + assert reconstructed_batch["next.reward"] == 2.0 + assert not 
reconstructed_batch["next.done"] + assert reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {"test": "no_obs"} + + +def test_minimal_batch(): + """Test with minimal batch containing only observation.* and action.""" + batch = {"observation.state": "minimal_state", "action": "minimal_action"} + + transition = _default_batch_to_transition(batch) + + # Check observation + assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"} + assert transition[TransitionKey.ACTION] == "minimal_action" + + # Check defaults + assert transition[TransitionKey.REWARD] == 0.0 + assert not transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} + + # Round trip + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["observation.state"] == "minimal_state" + assert reconstructed_batch["action"] == "minimal_action" + assert reconstructed_batch["next.reward"] == 0.0 + assert not reconstructed_batch["next.done"] + assert not reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {} + + +def test_empty_batch(): + """Test behavior with empty batch.""" + batch = {} + + transition = _default_batch_to_transition(batch) + + # All fields should have defaults + assert transition[TransitionKey.OBSERVATION] is None + assert transition[TransitionKey.ACTION] is None + assert transition[TransitionKey.REWARD] == 0.0 + assert not transition[TransitionKey.DONE] + assert not transition[TransitionKey.TRUNCATED] + assert transition[TransitionKey.INFO] == {} + assert transition[TransitionKey.COMPLEMENTARY_DATA] == {} + + # Round trip + reconstructed_batch = _default_transition_to_batch(transition) + assert reconstructed_batch["action"] is None + assert reconstructed_batch["next.reward"] == 0.0 + assert not reconstructed_batch["next.done"] + assert not 
reconstructed_batch["next.truncated"] + assert reconstructed_batch["info"] == {} + + +def test_complex_nested_observation(): + """Test with complex nested observation data.""" + batch = { + "observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890}, + "observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891}, + "observation.state": torch.randn(7), + "action": torch.randn(8), + "next.reward": 3.14, + "next.done": False, + "next.truncated": True, + "info": {"episode_length": 200, "success": True}, + } + + transition = _default_batch_to_transition(batch) + reconstructed_batch = _default_transition_to_batch(transition) + + # Check that all observation keys are preserved + original_obs_keys = {k for k in batch if k.startswith("observation.")} + reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")} + + assert original_obs_keys == reconstructed_obs_keys + + # Check tensor values + assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"]) + + # Check nested dict with tensors + assert torch.allclose( + batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"] + ) + assert torch.allclose( + batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"] + ) + + # Check action tensor + assert torch.allclose(batch["action"], reconstructed_batch["action"]) + + # Check other fields + assert batch["next.reward"] == reconstructed_batch["next.reward"] + assert batch["next.done"] == reconstructed_batch["next.done"] + assert batch["next.truncated"] == reconstructed_batch["next.truncated"] + assert batch["info"] == reconstructed_batch["info"] + + +def test_custom_converter(): + """Test that custom converters can still be used.""" + + def to_tr(batch): + # Custom converter that modifies the reward + tr = _default_batch_to_transition(batch) + # Double the reward + reward = 
tr.get(TransitionKey.REWARD, 0.0) + new_tr = tr.copy() + new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0 + return new_tr + + def to_batch(tr): + batch = _default_transition_to_batch(tr) + return batch + + processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch) + + batch = { + "observation.state": torch.randn(1, 4), + "action": torch.randn(1, 2), + "next.reward": 1.0, + "next.done": False, + } + + result = processor(batch) + + # Check the reward was doubled by our custom converter + assert result["next.reward"] == 2.0 + assert torch.allclose(result["observation.state"], batch["observation.state"]) + assert torch.allclose(result["action"], batch["action"]) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py new file mode 100644 index 000000000..26aea56c7 --- /dev/null +++ b/tests/processor/test_normalize_processor.py @@ -0,0 +1,628 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from unittest.mock import Mock + +import numpy as np +import pytest +import torch + +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.processor.normalize_processor import ( + NormalizerProcessor, + UnnormalizerProcessor, + _convert_stats_to_tensors, +) +from lerobot.processor.pipeline import RobotProcessor, TransitionKey + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +def test_numpy_conversion(): + stats = { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) + assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) + assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5])) + assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2])) + + +def test_tensor_conversion(): + stats = { + "action": { + "mean": torch.tensor([0.0, 0.0]), + "std": torch.tensor([1.0, 1.0]), + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert tensor_stats["action"]["mean"].dtype == torch.float32 + assert tensor_stats["action"]["std"].dtype == torch.float32 + + +def test_scalar_conversion(): + stats = { + "reward": { + "mean": 0.5, + "std": 0.1, + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5)) + assert torch.allclose(tensor_stats["reward"]["std"], 
torch.tensor(0.1)) + + +def test_list_conversion(): + stats = { + "observation.state": { + "min": [0.0, -1.0, -2.0], + "max": [1.0, 1.0, 2.0], + } + } + tensor_stats = _convert_stats_to_tensors(stats) + + assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) + assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) + + +def test_unsupported_type(): + stats = { + "bad_key": { + "mean": "string_value", + } + } + with pytest.raises(TypeError, match="Unsupported type"): + _convert_stats_to_tensors(stats) + + +# Helper functions to create feature maps and norm maps +def _create_observation_features(): + return { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + } + + +def _create_observation_norm_map(): + return { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + } + + +# Fixtures for observation normalisation tests using NormalizerProcessor +@pytest.fixture +def observation_stats(): + return { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + } + + +@pytest.fixture +def observation_normalizer(observation_stats): + """Return a NormalizerProcessor that only has observation stats (no action).""" + features = _create_observation_features() + norm_map = _create_observation_norm_map() + return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats) + + +def test_mean_std_normalization(observation_normalizer): + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = observation_normalizer(transition) + normalized_obs = 
normalized_transition[TransitionKey.OBSERVATION] + + # Check mean/std normalization + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(normalized_obs["observation.image"], expected_image) + + +def test_min_max_normalization(observation_normalizer): + observation = { + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = observation_normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Check min/max normalization to [-1, 1] + # For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0 + # For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0 + expected_state = torch.tensor([0.0, 0.0]) + assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6) + + +def test_selective_normalization(observation_stats): + features = _create_observation_features() + norm_map = _create_observation_norm_map() + normalizer = NormalizerProcessor( + features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"} + ) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + # Only image should be normalized + assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2) + # State should remain unchanged + assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"]) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_device_compatibility(observation_stats): + features = _create_observation_features() + norm_map = _create_observation_norm_map() + normalizer = NormalizerProcessor(features=features, 
norm_map=norm_map, stats=observation_stats) + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(), + } + transition = create_transition(observation=observation) + + normalized_transition = normalizer(transition) + normalized_obs = normalized_transition[TransitionKey.OBSERVATION] + + assert normalized_obs["observation.image"].device.type == "cuda" + + +def test_from_lerobot_dataset(): + # Mock dataset + mock_dataset = Mock() + mock_dataset.meta.stats = { + "observation.image": {"mean": [0.5], "std": [0.2]}, + "action": {"mean": [0.0], "std": [1.0]}, + } + + features = { + "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "action": PolicyFeature(FeatureType.ACTION, (1,)), + } + norm_map = { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + + # Both observation and action statistics should be present in tensor stats + assert "observation.image" in normalizer._tensor_stats + assert "action" in normalizer._tensor_stats + + +def test_state_dict_save_load(observation_normalizer): + # Save state + state_dict = observation_normalizer.state_dict() + + # Create new normalizer and load state + features = _create_observation_features() + norm_map = _create_observation_norm_map() + new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + new_normalizer.load_state_dict(state_dict) + + # Test that it works the same + observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])} + transition = create_transition(observation=observation) + + result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION] + result2 = new_normalizer(transition)[TransitionKey.OBSERVATION] + + assert torch.allclose(result1["observation.image"], result2["observation.image"]) + + +# Fixtures for ActionUnnormalizer tests +@pytest.fixture +def action_stats_mean_std(): + return 
{ + "mean": np.array([0.0, 0.0, 0.0]), + "std": np.array([1.0, 2.0, 0.5]), + } + + +@pytest.fixture +def action_stats_min_max(): + return { + "min": np.array([-1.0, -2.0, 0.0]), + "max": np.array([1.0, 2.0, 1.0]), + } + + +def _create_action_features(): + return { + "action": PolicyFeature(FeatureType.ACTION, (3,)), + } + + +def _create_action_norm_map_mean_std(): + return { + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + +def _create_action_norm_map_min_max(): + return { + FeatureType.ACTION: NormalizationMode.MIN_MAX, + } + + +def test_mean_std_unnormalization(action_stats_mean_std): + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) + + normalized_action = torch.tensor([1.0, -0.5, 2.0]) + transition = create_transition(action=normalized_action) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] + + # action * std + mean + expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0]) + assert torch.allclose(unnormalized_action, expected) + + +def test_min_max_unnormalization(action_stats_min_max): + features = _create_action_features() + norm_map = _create_action_norm_map_min_max() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_min_max} + ) + + # Actions in [-1, 1] + normalized_action = torch.tensor([0.0, -1.0, 1.0]) + transition = create_transition(action=normalized_action) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] + + # Map from [-1, 1] to [min, max] + # (action + 1) / 2 * (max - min) + min + expected = torch.tensor( + [ + (0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0 + (-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0 + (1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0 + 
] + ) + assert torch.allclose(unnormalized_action, expected) + + +def test_numpy_action_input(action_stats_mean_std): + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) + + normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32) + transition = create_transition(action=normalized_action) + + unnormalized_transition = unnormalizer(transition) + unnormalized_action = unnormalized_transition[TransitionKey.ACTION] + + assert isinstance(unnormalized_action, torch.Tensor) + expected = torch.tensor([1.0, -1.0, 1.0]) + assert torch.allclose(unnormalized_action, expected) + + +def test_none_action(action_stats_mean_std): + features = _create_action_features() + norm_map = _create_action_norm_map_mean_std() + unnormalizer = UnnormalizerProcessor( + features=features, norm_map=norm_map, stats={"action": action_stats_mean_std} + ) + + transition = create_transition() + result = unnormalizer(transition) + + # Should return transition unchanged + assert result == transition + + +def test_action_from_lerobot_dataset(): + mock_dataset = Mock() + mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}} + features = {"action": PolicyFeature(FeatureType.ACTION, (1,))} + norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD} + unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + assert "mean" in unnormalizer._tensor_stats["action"] + + +# Fixtures for NormalizerProcessor tests +@pytest.fixture +def full_stats(): + return { + "observation.image": { + "mean": np.array([0.5, 0.5, 0.5]), + "std": np.array([0.2, 0.2, 0.2]), + }, + "observation.state": { + "min": np.array([0.0, -1.0]), + "max": np.array([1.0, 1.0]), + }, + "action": { + "mean": np.array([0.0, 0.0]), + "std": np.array([1.0, 2.0]), + }, + } + + +def _create_full_features(): + return { + 
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)), + "observation.state": PolicyFeature(FeatureType.STATE, (2,)), + "action": PolicyFeature(FeatureType.ACTION, (2,)), + } + + +def _create_full_norm_map(): + return { + FeatureType.VISUAL: NormalizationMode.MEAN_STD, + FeatureType.STATE: NormalizationMode.MIN_MAX, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + } + + +@pytest.fixture +def normalizer_processor(full_stats): + features = _create_full_features() + norm_map = _create_full_norm_map() + return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats) + + +def test_combined_normalization(normalizer_processor): + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) + + processed_transition = normalizer_processor(transition) + + # Check normalized observations + processed_obs = processed_transition[TransitionKey.OBSERVATION] + expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2 + assert torch.allclose(processed_obs["observation.image"], expected_image) + + # Check normalized action + processed_action = processed_transition[TransitionKey.ACTION] + expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0]) + assert torch.allclose(processed_action, expected_action) + + # Check other fields remain unchanged + assert processed_transition[TransitionKey.REWARD] == 1.0 + assert not processed_transition[TransitionKey.DONE] + + +def test_processor_from_lerobot_dataset(full_stats): + # Mock dataset + mock_dataset = Mock() + mock_dataset.meta.stats = full_stats + + features = _create_full_features() + norm_map = _create_full_norm_map() + + processor = NormalizerProcessor.from_lerobot_dataset( + mock_dataset, features, norm_map, 
normalize_keys={"observation.image"} + ) + + assert processor.normalize_keys == {"observation.image"} + assert "observation.image" in processor._tensor_stats + assert "action" in processor._tensor_stats + + +def test_get_config(full_stats): + features = _create_full_features() + norm_map = _create_full_norm_map() + processor = NormalizerProcessor( + features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + ) + + config = processor.get_config() + expected_config = { + "normalize_keys": ["observation.image"], + "eps": 1e-6, + "features": { + "observation.image": {"type": "VISUAL", "shape": (3, 96, 96)}, + "observation.state": {"type": "STATE", "shape": (2,)}, + "action": {"type": "ACTION", "shape": (2,)}, + }, + "norm_map": { + "VISUAL": "MEAN_STD", + "STATE": "MIN_MAX", + "ACTION": "MEAN_STD", + }, + } + assert config == expected_config + + +def test_integration_with_robot_processor(normalizer_processor): + """Test integration with RobotProcessor pipeline""" + robot_processor = RobotProcessor([normalizer_processor]) + + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) + + processed_transition = robot_processor(transition) + + # Verify the processing worked + assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict) + assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor) + + +# Edge case tests +def test_empty_observation(): + stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + 
+ transition = create_transition() + result = normalizer(transition) + + assert result == transition + + +def test_empty_stats(): + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={}) + observation = {"observation.image": torch.tensor([0.5])} + transition = create_transition(observation=observation) + + result = normalizer(transition) + # Should return observation unchanged since no stats are available + assert torch.allclose( + result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"] + ) + + +def test_partial_stats(): + """If statistics are incomplete, the value should pass through unchanged.""" + stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max) + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats) + observation = {"observation.image": torch.tensor([0.7])} + transition = create_transition(observation=observation) + + processed = normalizer(transition)[TransitionKey.OBSERVATION] + assert torch.allclose(processed["observation.image"], observation["observation.image"]) + + +def test_missing_action_stats_no_error(): + mock_dataset = Mock() + mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}} + + features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))} + norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD} + + processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map) + # The tensor stats should not contain the 'action' key + assert "action" not in processor._tensor_stats + + +def test_serialization_roundtrip(full_stats): + """Test that features and norm_map can be serialized and 
deserialized correctly.""" + features = _create_full_features() + norm_map = _create_full_norm_map() + original_processor = NormalizerProcessor( + features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6 + ) + + # Get config (serialization) + config = original_processor.get_config() + + # Create a new processor from the config (deserialization) + new_processor = NormalizerProcessor( + features=config["features"], + norm_map=config["norm_map"], + stats=full_stats, + normalize_keys=set(config["normalize_keys"]), + eps=config["eps"], + ) + + # Test that both processors work the same way + observation = { + "observation.image": torch.tensor([0.7, 0.5, 0.3]), + "observation.state": torch.tensor([0.5, 0.0]), + } + action = torch.tensor([1.0, -0.5]) + transition = create_transition( + observation=observation, + action=action, + reward=1.0, + done=False, + truncated=False, + info={}, + complementary_data={}, + ) + + result1 = original_processor(transition) + result2 = new_processor(transition) + + # Compare results + assert torch.allclose( + result1[TransitionKey.OBSERVATION]["observation.image"], + result2[TransitionKey.OBSERVATION]["observation.image"], + ) + assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION]) + + # Verify features and norm_map are correctly reconstructed + assert new_processor.features.keys() == original_processor.features.keys() + for key in new_processor.features: + assert new_processor.features[key].type == original_processor.features[key].type + assert new_processor.features[key].shape == original_processor.features[key].shape + + assert new_processor.norm_map == original_processor.norm_map diff --git a/tests/processor/test_observation_processor.py b/tests/processor/test_observation_processor.py new file mode 100644 index 000000000..e48b6bc08 --- /dev/null +++ b/tests/processor/test_observation_processor.py @@ -0,0 +1,486 @@ +#!/usr/bin/env python + +# Copyright 2025 The 
HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +import torch + +from lerobot.configs.types import FeatureType +from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE +from lerobot.processor import VanillaObservationProcessor +from lerobot.processor.pipeline import TransitionKey +from tests.conftest import assert_contract_is_typed + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +def test_process_single_image(): + """Test processing a single image.""" + processor = VanillaObservationProcessor() + + # Create a mock image (H, W, C) format, uint8 + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that the image was processed correctly + assert "observation.image" in processed_obs + processed_img = processed_obs["observation.image"] + + # Check shape: 
should be (1, 3, 64, 64) - batch, channels, height, width + assert processed_img.shape == (1, 3, 64, 64) + + # Check dtype and range + assert processed_img.dtype == torch.float32 + assert processed_img.min() >= 0.0 + assert processed_img.max() <= 1.0 + + +def test_process_image_dict(): + """Test processing multiple images in a dictionary.""" + processor = VanillaObservationProcessor() + + # Create mock images + image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) + + observation = {"pixels": {"camera1": image1, "camera2": image2}} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that both images were processed + assert "observation.images.camera1" in processed_obs + assert "observation.images.camera2" in processed_obs + + # Check shapes + assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32) + assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48) + + +def test_process_batched_image(): + """Test processing already batched images.""" + processor = VanillaObservationProcessor() + + # Create a batched image (B, H, W, C) + image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8) + + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that batch dimension is preserved + assert processed_obs["observation.image"].shape == (2, 3, 64, 64) + + +def test_invalid_image_format(): + """Test error handling for invalid image formats.""" + processor = VanillaObservationProcessor() + + # Test wrong channel order (channels first) + image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + with 
pytest.raises(ValueError, match="Expected channel-last images"): + processor(transition) + + +def test_invalid_image_dtype(): + """Test error handling for invalid image dtype.""" + processor = VanillaObservationProcessor() + + # Test wrong dtype + image = np.random.rand(64, 64, 3).astype(np.float32) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + with pytest.raises(ValueError, match="Expected torch.uint8 images"): + processor(transition) + + +def test_no_pixels_in_observation(): + """Test processor when no pixels are in observation.""" + processor = VanillaObservationProcessor() + + observation = {"other_data": np.array([1, 2, 3])} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Should preserve other data unchanged + assert "other_data" in processed_obs + np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3])) + + +def test_none_observation(): + """Test processor with None observation.""" + processor = VanillaObservationProcessor() + + transition = create_transition() + result = processor(transition) + + assert result == transition + + +def test_serialization_methods(): + """Test serialization methods.""" + processor = VanillaObservationProcessor() + + # Test get_config + config = processor.get_config() + assert isinstance(config, dict) + + # Test state_dict + state = processor.state_dict() + assert isinstance(state, dict) + + # Test load_state_dict (should not raise) + processor.load_state_dict(state) + + # Test reset (should not raise) + processor.reset() + + +def test_process_environment_state(): + """Test processing environment_state.""" + processor = VanillaObservationProcessor() + + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + observation = {"environment_state": env_state} + transition = create_transition(observation=observation) + + result = processor(transition) + 
processed_obs = result[TransitionKey.OBSERVATION] + + # Check that environment_state was renamed and processed + assert "observation.environment_state" in processed_obs + assert "environment_state" not in processed_obs + + processed_state = processed_obs["observation.environment_state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]])) + + +def test_process_agent_pos(): + """Test processing agent_pos.""" + processor = VanillaObservationProcessor() + + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that agent_pos was renamed and processed + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + processed_state = processed_obs["observation.state"] + assert processed_state.shape == (1, 3) # Batch dimension added + assert processed_state.dtype == torch.float32 + torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]])) + + +def test_process_batched_states(): + """Test processing already batched states.""" + processor = VanillaObservationProcessor() + + env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32) + + observation = {"environment_state": env_state, "agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that batch dimensions are preserved + assert processed_obs["observation.environment_state"].shape == (2, 2) + assert processed_obs["observation.state"].shape == (2, 2) + + +def test_process_both_states(): + """Test processing both environment_state and 
agent_pos.""" + processor = VanillaObservationProcessor() + + env_state = np.array([1.0, 2.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5], dtype=np.float32) + + observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that both states were processed + assert "observation.environment_state" in processed_obs + assert "observation.state" in processed_obs + + # Check that original keys were removed + assert "environment_state" not in processed_obs + assert "agent_pos" not in processed_obs + + # Check that other data was preserved + assert processed_obs["other_data"] == "keep_me" + + +def test_no_states_in_observation(): + """Test processor when no states are in observation.""" + processor = VanillaObservationProcessor() + + observation = {"other_data": np.array([1, 2, 3])} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Should preserve data unchanged + np.testing.assert_array_equal(processed_obs, observation) + + +def test_complete_observation_processing(): + """Test processing a complete observation with both images and states.""" + processor = VanillaObservationProcessor() + + # Create mock data + image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = { + "pixels": image, + "environment_state": env_state, + "agent_pos": agent_pos, + "other_data": "preserve_me", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that image was processed + assert "observation.image" in processed_obs + assert 
processed_obs["observation.image"].shape == (1, 3, 32, 32) + + # Check that states were processed + assert "observation.environment_state" in processed_obs + assert "observation.state" in processed_obs + + # Check that original keys were removed + assert "pixels" not in processed_obs + assert "environment_state" not in processed_obs + assert "agent_pos" not in processed_obs + + # Check that other data was preserved + assert processed_obs["other_data"] == "preserve_me" + + +def test_image_only_processing(): + """Test processing observation with only images.""" + processor = VanillaObservationProcessor() + + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + observation = {"pixels": image} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.image" in processed_obs + assert len(processed_obs) == 1 + + +def test_state_only_processing(): + """Test processing observation with only states.""" + processor = VanillaObservationProcessor() + + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + observation = {"agent_pos": agent_pos} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.state" in processed_obs + assert "agent_pos" not in processed_obs + + +def test_empty_observation(): + """Test processing empty observation.""" + processor = VanillaObservationProcessor() + + observation = {} + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert processed_obs == {} + + +def test_equivalent_to_original_function(): + """Test that ObservationProcessor produces equivalent results to preprocess_observation.""" + # Import the original function for comparison + from lerobot.envs.utils import preprocess_observation + + processor = 
VanillaObservationProcessor() + + # Create test data similar to what the original function expects + image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8) + env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32) + agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32) + + observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = create_transition(observation=observation) + processor_result = processor(transition)[TransitionKey.OBSERVATION] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) + + +def test_equivalent_with_image_dict(): + """Test equivalence with dictionary of images.""" + from lerobot.envs.utils import preprocess_observation + + processor = VanillaObservationProcessor() + + # Create test data with multiple cameras + image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8) + image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8) + agent_pos = np.array([1.0, 2.0], dtype=np.float32) + + observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos} + + # Process with original function + original_result = preprocess_observation(observation) + + # Process with new processor + transition = create_transition(observation=observation) + processor_result = processor(transition)[TransitionKey.OBSERVATION] + + # Compare results + assert set(original_result.keys()) == set(processor_result.keys()) + + for key in original_result: + torch.testing.assert_close(original_result[key], processor_result[key]) + + +def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "pixels": 
policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"] + assert "pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"] + assert "observation.pixels" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory): + processor = VanillaObservationProcessor() + features = { + "pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "keep": policy_feature_factory(FeatureType.ENV, (7,)), + } + out = processor.feature_contract(features.copy()) + + assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"] + assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"] + assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"] + assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory): + processor = 
VanillaObservationProcessor() + features = { + "environment_state": policy_feature_factory(FeatureType.STATE, (3,)), + "agent_pos": policy_feature_factory(FeatureType.STATE, (7,)), + "keep": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = processor.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert out["keep"] == features["keep"] + assert_contract_is_typed(out) + + +def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory): + proc = VanillaObservationProcessor() + features = { + "observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)), + "observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)), + } + out = proc.feature_contract(features.copy()) + + assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"] + assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"] + assert "environment_state" not in out and "agent_pos" not in out + assert_contract_is_typed(out) diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py new file mode 100644 index 000000000..5665d5a7d --- /dev/null +++ b/tests/processor/test_pipeline.py @@ -0,0 +1,1919 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import tempfile +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import pytest +import torch +import torch.nn as nn + +from lerobot.configs.types import FeatureType, PolicyFeature +from lerobot.processor import EnvTransition, ProcessorStepRegistry, RobotProcessor +from lerobot.processor.pipeline import TransitionKey +from tests.conftest import assert_contract_is_typed + + +def create_transition( + observation=None, action=None, reward=0.0, done=False, truncated=False, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info if info is not None else {}, + TransitionKey.COMPLEMENTARY_DATA: complementary_data if complementary_data is not None else {}, + } + + +@dataclass +class MockStep: + """Mock pipeline step for testing - demonstrates best practices. + + This example shows the proper separation: + - JSON-serializable attributes (name, counter) go in get_config() + - Only torch tensors go in state_dict() + + Note: The counter is part of the configuration, so it will be restored + when the step is recreated from config during loading. 
+ """ + + name: str = "mock_step" + counter: int = 0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Add a counter to the complementary_data.""" + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + comp_data = {} if comp_data is None else dict(comp_data) # Make a copy + + comp_data[f"{self.name}_counter"] = self.counter + self.counter += 1 + + # Create a new transition with updated complementary_data + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + def get_config(self) -> dict[str, Any]: + # Return all JSON-serializable attributes that should be persisted + # These will be passed to __init__ when loading + return {"name": self.name, "counter": self.counter} + + def state_dict(self) -> dict[str, torch.Tensor]: + # Only return torch tensors (empty in this case since we have no tensor state) + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + # No tensor state to load + pass + + def reset(self) -> None: + self.counter = 0 + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +@dataclass +class MockStepWithoutOptionalMethods: + """Mock step that only implements the required __call__ method.""" + + multiplier: float = 2.0 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Multiply reward by multiplier.""" + reward = transition.get(TransitionKey.REWARD) + + if reward is not None: + new_transition = transition.copy() + new_transition[TransitionKey.REWARD] = reward * self.multiplier + return new_transition + + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +@dataclass +class MockStepWithTensorState: + """Mock step demonstrating mixed JSON attributes and tensor state.""" 
+ + name: str = "tensor_step" + learning_rate: float = 0.01 + window_size: int = 10 + + def __init__(self, name: str = "tensor_step", learning_rate: float = 0.01, window_size: int = 10): + self.name = name + self.learning_rate = learning_rate + self.window_size = window_size + # Tensor state + self.running_mean = torch.zeros(window_size) + self.running_count = torch.tensor(0) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Update running statistics.""" + reward = transition.get(TransitionKey.REWARD) + + if reward is not None: + # Update running mean + idx = self.running_count % self.window_size + self.running_mean[idx] = reward + self.running_count += 1 + + return transition + + def get_config(self) -> dict[str, Any]: + # Only JSON-serializable attributes + return { + "name": self.name, + "learning_rate": self.learning_rate, + "window_size": self.window_size, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + # Only tensor state + return { + "running_mean": self.running_mean, + "running_count": self.running_count, + } + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + self.running_mean = state["running_mean"] + self.running_count = state["running_count"] + + def reset(self) -> None: + self.running_mean.zero_() + self.running_count.zero_() + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +def test_empty_pipeline(): + """Test pipeline with no steps.""" + pipeline = RobotProcessor() + + transition = create_transition() + result = pipeline(transition) + + assert result == transition + assert len(pipeline) == 0 + + +def test_single_step_pipeline(): + """Test pipeline with a single step.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + transition = create_transition() + result = pipeline(transition) + + assert len(pipeline) == 1 + assert 
result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 + + # Call again to test counter increment + result = pipeline(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 1 + + +def test_multiple_steps_pipeline(): + """Test pipeline with multiple steps.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + pipeline = RobotProcessor([step1, step2]) + + transition = create_transition() + result = pipeline(transition) + + assert len(pipeline) == 2 + assert result[TransitionKey.COMPLEMENTARY_DATA]["step1_counter"] == 0 + assert result[TransitionKey.COMPLEMENTARY_DATA]["step2_counter"] == 0 + + +def test_invalid_transition_format(): + """Test pipeline with invalid transition format.""" + pipeline = RobotProcessor([MockStep()]) + + # Test with wrong type (tuple instead of dict) + with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): + pipeline((None, None, 0.0, False, False, {}, {})) # Tuple instead of dict + + # Test with wrong type (string) + with pytest.raises(ValueError, match="EnvTransition must be a dictionary"): + pipeline("not a dict") + + +def test_step_through(): + """Test step_through method with dict input.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + pipeline = RobotProcessor([step1, step2]) + + transition = create_transition() + + results = list(pipeline.step_through(transition)) + + assert len(results) == 3 # Original + 2 steps + assert results[0] == transition # Original + assert "step1_counter" in results[1][TransitionKey.COMPLEMENTARY_DATA] # After step1 + assert "step2_counter" in results[2][TransitionKey.COMPLEMENTARY_DATA] # After step2 + + # Ensure all results are dicts (same format as input) + for result in results: + assert isinstance(result, dict) + assert all(isinstance(k, TransitionKey) for k in result.keys()) + + +def test_step_through_with_dict(): + """Test step_through method with dict input.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + 
pipeline = RobotProcessor([step1, step2]) + + batch = { + "observation.image": None, + "action": None, + "next.reward": 0.0, + "next.done": False, + "next.truncated": False, + "info": {}, + } + + results = list(pipeline.step_through(batch)) + + assert len(results) == 3 # Original + 2 steps + + # Ensure all results are EnvTransition dicts (regardless of input format) + for result in results: + assert isinstance(result, dict) + # Check that keys are TransitionKey enums or at least valid transition keys + for key in result: + assert key in [ + TransitionKey.OBSERVATION, + TransitionKey.ACTION, + TransitionKey.REWARD, + TransitionKey.DONE, + TransitionKey.TRUNCATED, + TransitionKey.INFO, + TransitionKey.COMPLEMENTARY_DATA, + ] + + # Check that the processing worked - verify step counters in complementary_data + assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 + assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 + assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0 + + +def test_step_through_no_hooks(): + """Test that step_through doesn't execute hooks.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + hook_calls = [] + + def tracking_hook(idx: int, transition: EnvTransition): + hook_calls.append(f"hook_called_step_{idx}") + + # Register hooks + pipeline.register_before_step_hook(tracking_hook) + pipeline.register_after_step_hook(tracking_hook) + + # Use step_through + transition = create_transition() + results = list(pipeline.step_through(transition)) + + # Verify step was executed (counter should increment) + assert len(results) == 2 # Initial + 1 step + assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 + + # Verify hooks were NOT called + assert len(hook_calls) == 0 + + # Now use __call__ to verify hooks ARE called there + hook_calls.clear() + pipeline(transition) + + # Verify hooks were called (before and after 
for 1 step = 2 calls) + assert len(hook_calls) == 2 + assert hook_calls == ["hook_called_step_0", "hook_called_step_0"] + + +def test_indexing(): + """Test pipeline indexing.""" + step1 = MockStep("step1") + step2 = MockStep("step2") + pipeline = RobotProcessor([step1, step2]) + + # Test integer indexing + assert pipeline[0] is step1 + assert pipeline[1] is step2 + + # Test slice indexing + sub_pipeline = pipeline[0:1] + assert isinstance(sub_pipeline, RobotProcessor) + assert len(sub_pipeline) == 1 + assert sub_pipeline[0] is step1 + + +def test_hooks(): + """Test before/after step hooks.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + before_calls = [] + after_calls = [] + + def before_hook(idx: int, transition: EnvTransition): + before_calls.append(idx) + + def after_hook(idx: int, transition: EnvTransition): + after_calls.append(idx) + + pipeline.register_before_step_hook(before_hook) + pipeline.register_after_step_hook(after_hook) + + transition = create_transition() + pipeline(transition) + + assert before_calls == [0] + assert after_calls == [0] + + +def test_unregister_hooks(): + """Test unregistering hooks from the pipeline.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + # Test before_step_hook + before_calls = [] + + def before_hook(idx: int, transition: EnvTransition): + before_calls.append(idx) + + pipeline.register_before_step_hook(before_hook) + + # Verify hook is registered + transition = create_transition() + pipeline(transition) + assert len(before_calls) == 1 + + # Unregister and verify it's no longer called + pipeline.unregister_before_step_hook(before_hook) + before_calls.clear() + pipeline(transition) + assert len(before_calls) == 0 + + # Test after_step_hook + after_calls = [] + + def after_hook(idx: int, transition: EnvTransition): + after_calls.append(idx) + + pipeline.register_after_step_hook(after_hook) + pipeline(transition) + assert len(after_calls) == 1 + + 
pipeline.unregister_after_step_hook(after_hook) + after_calls.clear() + pipeline(transition) + assert len(after_calls) == 0 + + +def test_unregister_nonexistent_hook(): + """Test error handling when unregistering hooks that don't exist.""" + pipeline = RobotProcessor([MockStep()]) + + def some_hook(idx: int, transition: EnvTransition): + pass + + def reset_hook(): + pass + + # Test unregistering hooks that were never registered + with pytest.raises(ValueError, match="not found in before_step_hooks"): + pipeline.unregister_before_step_hook(some_hook) + + with pytest.raises(ValueError, match="not found in after_step_hooks"): + pipeline.unregister_after_step_hook(some_hook) + + +def test_multiple_hooks_and_selective_unregister(): + """Test registering multiple hooks and selectively unregistering them.""" + pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")]) + + calls_1 = [] + calls_2 = [] + calls_3 = [] + + def hook1(idx: int, transition: EnvTransition): + calls_1.append(f"hook1_step{idx}") + + def hook2(idx: int, transition: EnvTransition): + calls_2.append(f"hook2_step{idx}") + + def hook3(idx: int, transition: EnvTransition): + calls_3.append(f"hook3_step{idx}") + + # Register multiple hooks + pipeline.register_before_step_hook(hook1) + pipeline.register_before_step_hook(hook2) + pipeline.register_before_step_hook(hook3) + + # Run pipeline - all hooks should be called for both steps + transition = create_transition() + pipeline(transition) + + assert calls_1 == ["hook1_step0", "hook1_step1"] + assert calls_2 == ["hook2_step0", "hook2_step1"] + assert calls_3 == ["hook3_step0", "hook3_step1"] + + # Clear calls + calls_1.clear() + calls_2.clear() + calls_3.clear() + + # Unregister middle hook + pipeline.unregister_before_step_hook(hook2) + + # Run again - only hook1 and hook3 should be called + pipeline(transition) + + assert calls_1 == ["hook1_step0", "hook1_step1"] + assert calls_2 == [] # hook2 was unregistered + assert calls_3 == ["hook3_step0", 
"hook3_step1"] + + +def test_hook_execution_order_documentation(): + """Test and document that hooks are executed sequentially in registration order.""" + pipeline = RobotProcessor([MockStep("step")]) + + execution_order = [] + + def hook_a(idx: int, transition: EnvTransition): + execution_order.append("A") + + def hook_b(idx: int, transition: EnvTransition): + execution_order.append("B") + + def hook_c(idx: int, transition: EnvTransition): + execution_order.append("C") + + # Register in specific order: A, B, C + pipeline.register_before_step_hook(hook_a) + pipeline.register_before_step_hook(hook_b) + pipeline.register_before_step_hook(hook_c) + + transition = create_transition() + pipeline(transition) + + # Verify execution order matches registration order + assert execution_order == ["A", "B", "C"] + + # Test that after unregistering B and re-registering it, it goes to the end + pipeline.unregister_before_step_hook(hook_b) + execution_order.clear() + + pipeline(transition) + assert execution_order == ["A", "C"] # B is gone + + # Re-register B - it should now be at the end + pipeline.register_before_step_hook(hook_b) + execution_order.clear() + + pipeline(transition) + assert execution_order == ["A", "C", "B"] # B is now last + + +def test_save_and_load_pretrained(): + """Test saving and loading pipeline. + + This test demonstrates that JSON-serializable attributes (like counter) + are saved in the config and restored when the step is recreated. 
+ """ + step1 = MockStep("step1") + step2 = MockStep("step2") + + # Increment counters to have some state + step1.counter = 5 + step2.counter = 10 + + pipeline = RobotProcessor([step1, step2], name="TestPipeline") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check files were created + config_path = Path(tmp_dir) / "testpipeline.json" # Based on name="TestPipeline" + assert config_path.exists() + + # Check config content + with open(config_path) as f: + config = json.load(f) + + assert config["name"] == "TestPipeline" + assert len(config["steps"]) == 2 + + # Verify counters are saved in config, not in separate state files + assert config["steps"][0]["config"]["counter"] == 5 + assert config["steps"][1]["config"]["counter"] == 10 + + # Load pipeline + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert loaded_pipeline.name == "TestPipeline" + assert len(loaded_pipeline) == 2 + + # Check that counter was restored from config + assert loaded_pipeline.steps[0].counter == 5 + assert loaded_pipeline.steps[1].counter == 10 + + +def test_step_without_optional_methods(): + """Test pipeline with steps that don't implement optional methods.""" + step = MockStepWithoutOptionalMethods(multiplier=3.0) + pipeline = RobotProcessor([step]) + + transition = create_transition(reward=2.0) + result = pipeline(transition) + + assert result[TransitionKey.REWARD] == 6.0 # 2.0 * 3.0 + + # Reset should work even if step doesn't implement reset + pipeline.reset() + + # Save/load should work even without optional methods + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + assert len(loaded_pipeline) == 1 + + +def test_mixed_json_and_tensor_state(): + """Test step with both JSON attributes and tensor state.""" + step = MockStepWithTensorState(name="stats", learning_rate=0.05, window_size=5) + pipeline = 
RobotProcessor([step]) + + # Process some transitions with rewards + for i in range(10): + transition = create_transition(reward=float(i)) + pipeline(transition) + + # Check state + assert step.running_count.item() == 10 + assert step.learning_rate == 0.05 + + # Save and load + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check that both config and state files were created + config_path = Path(tmp_dir) / "robotprocessor.json" # Default name is "RobotProcessor" + state_path = Path(tmp_dir) / "robotprocessor_step_0.safetensors" + assert config_path.exists() + assert state_path.exists() + + # Load and verify + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_step = loaded_pipeline.steps[0] + + # Check JSON attributes were restored + assert loaded_step.name == "stats" + assert loaded_step.learning_rate == 0.05 + assert loaded_step.window_size == 5 + + # Check tensor state was restored + assert loaded_step.running_count.item() == 10 + assert torch.allclose(loaded_step.running_mean, step.running_mean) + + +class MockModuleStep(nn.Module): + """Mock step that inherits from nn.Module to test state_dict handling of module parameters.""" + + def __init__(self, input_dim: int = 10, hidden_dim: int = 5): + super().__init__() + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.linear = nn.Linear(input_dim, hidden_dim) + self.running_mean = nn.Parameter(torch.zeros(hidden_dim), requires_grad=False) + self.counter = 0 # Non-tensor state + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Process transition and update running mean.""" + obs = transition.get(TransitionKey.OBSERVATION) + + if obs is not None and isinstance(obs, torch.Tensor): + # Process observation through linear layer + processed = self.forward(obs[:, : self.input_dim]) + + # Update running mean in-place (don't reassign the parameter) + 
with torch.no_grad(): + self.running_mean.mul_(0.9).add_(processed.mean(dim=0), alpha=0.1) + + self.counter += 1 + + return transition + + def get_config(self) -> dict[str, Any]: + return { + "input_dim": self.input_dim, + "hidden_dim": self.hidden_dim, + "counter": self.counter, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Override to return all module parameters and buffers.""" + # Get the module's state dict (includes all parameters and buffers) + return super().state_dict() + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Override to load all module parameters and buffers.""" + # Use the module's load_state_dict + super().load_state_dict(state) + + def reset(self) -> None: + self.running_mean.zero_() + self.counter = 0 + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +class MockNonModuleStepWithState: + """Mock step that explicitly does NOT inherit from nn.Module but has tensor state. + + This tests the state_dict/load_state_dict path for regular classes. 
+ """ + + def __init__(self, name: str = "non_module_step", feature_dim: int = 10): + self.name = name + self.feature_dim = feature_dim + + # Initialize tensor state - these are regular tensors, not nn.Parameters + self.weights = torch.randn(feature_dim, feature_dim) + self.bias = torch.zeros(feature_dim) + self.running_stats = torch.zeros(feature_dim) + self.step_count = torch.tensor(0) + + # Non-tensor state + self.config_value = 42 + self.history = [] + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """Process transition using tensor operations.""" + obs = transition.get(TransitionKey.OBSERVATION) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + if obs is not None and isinstance(obs, torch.Tensor) and obs.numel() >= self.feature_dim: + # Perform some tensor operations + flat_obs = obs.flatten()[: self.feature_dim] + + # Simple linear transformation (ensure dimensions match for matmul) + output = torch.matmul(self.weights.T, flat_obs) + self.bias + + # Update running stats + self.running_stats = 0.9 * self.running_stats + 0.1 * output + self.step_count += 1 + + # Add to complementary data + comp_data = {} if comp_data is None else dict(comp_data) + comp_data[f"{self.name}_mean_output"] = output.mean().item() + comp_data[f"{self.name}_steps"] = self.step_count.item() + + # Return updated transition + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + return transition + + def get_config(self) -> dict[str, Any]: + return { + "name": self.name, + "feature_dim": self.feature_dim, + "config_value": self.config_value, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + """Return only tensor state.""" + return { + "weights": self.weights, + "bias": self.bias, + "running_stats": self.running_stats, + "step_count": self.step_count, + } + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + """Load tensor state.""" + self.weights = 
state["weights"] + self.bias = state["bias"] + self.running_stats = state["running_stats"] + self.step_count = state["step_count"] + + def reset(self) -> None: + """Reset statistics but keep learned parameters.""" + self.running_stats.zero_() + self.step_count.zero_() + self.history.clear() + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +# Tests for overrides functionality +@dataclass +class MockStepWithNonSerializableParam: + """Mock step that requires a non-serializable parameter.""" + + def __init__(self, name: str = "mock_env_step", multiplier: float = 1.0, env: Any = None): + self.name = name + # Add type validation for multiplier + if isinstance(multiplier, str): + raise ValueError(f"multiplier must be a number, got string '{multiplier}'") + if not isinstance(multiplier, (int, float)): + raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}") + self.multiplier = float(multiplier) + self.env = env # Non-serializable parameter (like gym.Env) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + reward = transition.get(TransitionKey.REWARD) + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + # Use the env parameter if provided + if self.env is not None: + comp_data = {} if comp_data is None else dict(comp_data) + comp_data[f"{self.name}_env_info"] = str(self.env) + + # Apply multiplier to reward + new_transition = transition.copy() + if reward is not None: + new_transition[TransitionKey.REWARD] = reward * self.multiplier + + if comp_data: + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + + return new_transition + + def get_config(self) -> dict[str, Any]: + # Note: env is intentionally NOT included here as it's not serializable + return { + "name": self.name, + "multiplier": self.multiplier, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def 
load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +@ProcessorStepRegistry.register("registered_mock_step") +@dataclass +class RegisteredMockStep: + """Mock step registered in the registry.""" + + value: int = 42 + device: str = "cpu" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + + comp_data = {} if comp_data is None else dict(comp_data) + comp_data["registered_step_value"] = self.value + comp_data["registered_step_device"] = self.device + + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + def get_config(self) -> dict[str, Any]: + return { + "value": self.value, + "device": self.device, + } + + def state_dict(self) -> dict[str, torch.Tensor]: + return {} + + def load_state_dict(self, state: dict[str, torch.Tensor]) -> None: + pass + + def reset(self) -> None: + pass + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + +class MockEnvironment: + """Mock environment for testing non-serializable parameters.""" + + def __init__(self, name: str): + self.name = name + + def __str__(self): + return f"MockEnvironment({self.name})" + + +def test_from_pretrained_with_overrides(): + """Test loading processor with parameter overrides.""" + # Create a processor with steps that need overrides + env_step = MockStepWithNonSerializableParam(name="env_step", multiplier=2.0) + registered_step = RegisteredMockStep(value=100, device="cpu") + + pipeline = RobotProcessor([env_step, registered_step], name="TestOverrides") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save the pipeline + 
pipeline.save_pretrained(tmp_dir) + + # Create a mock environment for override + mock_env = MockEnvironment("test_env") + + # Load with overrides + overrides = { + "MockStepWithNonSerializableParam": { + "env": mock_env, + "multiplier": 3.0, # Override the multiplier too + }, + "registered_mock_step": {"device": "cuda", "value": 200}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Verify the pipeline was loaded correctly + assert len(loaded_pipeline) == 2 + assert loaded_pipeline.name == "TestOverrides" + + # Test the loaded steps + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + # Check that overrides were applied + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert "env_step_env_info" in comp_data + assert comp_data["env_step_env_info"] == "MockEnvironment(test_env)" + assert comp_data["registered_step_value"] == 200 + assert comp_data["registered_step_device"] == "cuda" + + # Check that multiplier override was applied + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 (overridden multiplier) + + +def test_from_pretrained_with_partial_overrides(): + """Test loading processor with overrides for only some steps.""" + step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) + step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) + + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override only one step + overrides = {"MockStepWithNonSerializableParam": {"multiplier": 5.0}} + + # The current implementation applies overrides to ALL steps with the same class name + # Both steps will get the override + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + # The reward should be affected by both steps, both getting the override + # First step: 
1.0 * 5.0 = 5.0 (overridden) + # Second step: 5.0 * 5.0 = 25.0 (also overridden) + assert result[TransitionKey.REWARD] == 25.0 + + +def test_from_pretrained_invalid_override_key(): + """Test that invalid override keys raise KeyError.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override a non-existent step + overrides = {"NonExistentStep": {"param": "value"}} + + with pytest.raises(KeyError, match="Override keys.*do not match any step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_multiple_invalid_override_keys(): + """Test that multiple invalid override keys are reported.""" + step = MockStepWithNonSerializableParam() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override multiple non-existent steps + overrides = {"NonExistentStep1": {"param": "value1"}, "NonExistentStep2": {"param": "value2"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "NonExistentStep1" in error_msg + assert "NonExistentStep2" in error_msg + assert "Available step keys" in error_msg + + +def test_from_pretrained_registered_step_override(): + """Test overriding registered steps using registry names.""" + registered_step = RegisteredMockStep(value=50, device="cpu") + pipeline = RobotProcessor([registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override using registry name + overrides = {"registered_mock_step": {"value": 999, "device": "cuda"}} + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test that overrides were applied + transition = create_transition() + result = loaded_pipeline(transition) + + comp_data = 
result[TransitionKey.COMPLEMENTARY_DATA] + assert comp_data["registered_step_value"] == 999 + assert comp_data["registered_step_device"] == "cuda" + + +def test_from_pretrained_mixed_registered_and_unregistered(): + """Test overriding both registered and unregistered steps.""" + unregistered_step = MockStepWithNonSerializableParam(name="unregistered", multiplier=1.0) + registered_step = RegisteredMockStep(value=10, device="cpu") + + pipeline = RobotProcessor([unregistered_step, registered_step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + mock_env = MockEnvironment("mixed_test") + + overrides = { + "MockStepWithNonSerializableParam": {"env": mock_env, "multiplier": 4.0}, + "registered_mock_step": {"value": 777}, + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + # Test both steps + transition = create_transition(reward=2.0) + result = loaded_pipeline(transition) + + comp_data = result[TransitionKey.COMPLEMENTARY_DATA] + assert comp_data["unregistered_env_info"] == "MockEnvironment(mixed_test)" + assert comp_data["registered_step_value"] == 777 + assert result[TransitionKey.REWARD] == 8.0 # 2.0 * 4.0 + + +def test_from_pretrained_no_overrides(): + """Test that from_pretrained works without overrides (backward compatibility).""" + step = MockStepWithNonSerializableParam(name="no_override", multiplier=3.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load without overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert len(loaded_pipeline) == 1 + + # Test that the step works (env will be None) + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + assert result[TransitionKey.REWARD] == 3.0 # 1.0 * 3.0 + + +def test_from_pretrained_empty_overrides(): + """Test that from_pretrained works with empty overrides dict.""" + step = 
MockStepWithNonSerializableParam(multiplier=2.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with empty overrides + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides={}) + + assert len(loaded_pipeline) == 1 + + # Test that the step works normally + transition = create_transition(reward=1.0) + result = loaded_pipeline(transition) + + assert result[TransitionKey.REWARD] == 2.0 + + +def test_from_pretrained_override_instantiation_error(): + """Test that instantiation errors with overrides are properly reported.""" + step = MockStepWithNonSerializableParam(multiplier=1.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override with invalid parameter type + overrides = { + "MockStepWithNonSerializableParam": { + "multiplier": "invalid_type" # Should be float, not string + } + } + + with pytest.raises(ValueError, match="Failed to instantiate processor step"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_from_pretrained_with_state_and_overrides(): + """Test that overrides work correctly with steps that have tensor state.""" + step = MockStepWithTensorState(name="tensor_step", learning_rate=0.01, window_size=5) + pipeline = RobotProcessor([step]) + + # Process some data to create state + for i in range(10): + transition = create_transition(reward=float(i)) + pipeline(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with overrides + overrides = { + "MockStepWithTensorState": { + "learning_rate": 0.05, # Override learning rate + "window_size": 3, # Override window size + } + } + + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + loaded_step = loaded_pipeline.steps[0] + + # Check that config overrides were applied + assert loaded_step.learning_rate == 0.05 + 
assert loaded_step.window_size == 3 + + # Check that tensor state was preserved + assert loaded_step.running_count.item() == 10 + + # The running_mean should still have the original window_size (5) from saved state + # but the new step will use window_size=3 for future operations + assert loaded_step.running_mean.shape[0] == 5 # From saved state + + +def test_from_pretrained_override_error_messages(): + """Test that error messages for override failures are helpful.""" + step1 = MockStepWithNonSerializableParam(name="step1") + step2 = RegisteredMockStep() + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Test with invalid override key + overrides = {"WrongStepName": {"param": "value"}} + + with pytest.raises(KeyError) as exc_info: + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + error_msg = str(exc_info.value) + assert "WrongStepName" in error_msg + assert "Available step keys" in error_msg + assert "MockStepWithNonSerializableParam" in error_msg + assert "registered_mock_step" in error_msg + + +def test_repr_empty_processor(): + """Test __repr__ with empty processor.""" + pipeline = RobotProcessor() + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=0: [])" + assert repr_str == expected + + +def test_repr_single_step(): + """Test __repr__ with single step.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_multiple_steps_under_limit(): + """Test __repr__ with 2-3 steps (all shown).""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + pipeline = RobotProcessor([step1, step2]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + assert repr_str == 
expected + + # Test with 3 steps (boundary case) + step3 = MockStepWithTensorState() + pipeline = RobotProcessor([step1, step2, step3]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=3: [MockStep, MockStepWithoutOptionalMethods, MockStepWithTensorState])" + assert repr_str == expected + + + def test_repr_many_steps_truncated(): + """Test __repr__ with more than 3 steps (truncated with ellipsis).""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + step3 = MockStepWithTensorState() + step4 = MockModuleStep() + step5 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4, step5]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=5: [MockStep, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + + def test_repr_with_custom_name(): + """Test __repr__ with custom processor name.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step], name="CustomProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='CustomProcessor', steps=1: [MockStep])" + assert repr_str == expected + + + def test_repr_with_seed(): + """Test __repr__ for a processor constructed without a seed (seed is not part of the repr).""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='RobotProcessor', steps=1: [MockStep])" + assert repr_str == expected + + + def test_repr_with_custom_name_and_seed(): + """Test __repr__ with a custom name; no seed is passed, so none is shown.""" + step1 = MockStep("step1") + step2 = MockStepWithoutOptionalMethods() + pipeline = RobotProcessor([step1, step2], name="MyProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='MyProcessor', steps=2: [MockStep, MockStepWithoutOptionalMethods])" + assert repr_str == expected + + + def test_repr_without_seed(): + """Test __repr__ when no seed is provided (seed should not appear in the repr).""" + step = 
MockStep("test_step") + pipeline = RobotProcessor([step], name="TestProcessor") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='TestProcessor', steps=1: [MockStep])" + assert repr_str == expected + + +def test_repr_various_step_types(): + """Test __repr__ with different types of steps to verify class name extraction.""" + step1 = MockStep() + step2 = MockStepWithTensorState() + step3 = MockModuleStep() + step4 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4], name="MixedSteps") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='MixedSteps', steps=4: [MockStep, MockStepWithTensorState, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +def test_repr_edge_case_long_names(): + """Test __repr__ handles steps with long class names properly.""" + step1 = MockStepWithNonSerializableParam() + step2 = MockStepWithoutOptionalMethods() + step3 = MockStepWithTensorState() + step4 = MockNonModuleStepWithState() + + pipeline = RobotProcessor([step1, step2, step3, step4], name="LongNames") + repr_str = repr(pipeline) + + expected = "RobotProcessor(name='LongNames', steps=4: [MockStepWithNonSerializableParam, MockStepWithoutOptionalMethods, ..., MockNonModuleStepWithState])" + assert repr_str == expected + + +# Tests for config filename features and multiple processors +def test_save_with_custom_config_filename(): + """Test saving processor with custom config filename.""" + step = MockStep("test") + pipeline = RobotProcessor([step], name="TestProcessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save with custom filename + pipeline.save_pretrained(tmp_dir, config_filename="my_custom_config.json") + + # Check file exists + config_path = Path(tmp_dir) / "my_custom_config.json" + assert config_path.exists() + + # Check content + with open(config_path) as f: + config = json.load(f) + assert config["name"] == "TestProcessor" + + # Load with specific filename + loaded = 
RobotProcessor.from_pretrained(tmp_dir, config_filename="my_custom_config.json") + assert loaded.name == "TestProcessor" + + +def test_multiple_processors_same_directory(): + """Test saving multiple processors to the same directory with different config files.""" + # Create different processors + preprocessor = RobotProcessor([MockStep("pre1"), MockStep("pre2")], name="preprocessor") + + postprocessor = RobotProcessor([MockStepWithoutOptionalMethods(multiplier=0.5)], name="postprocessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save both to same directory + preprocessor.save_pretrained(tmp_dir) + postprocessor.save_pretrained(tmp_dir) + + # Check both config files exist + assert (Path(tmp_dir) / "preprocessor.json").exists() + assert (Path(tmp_dir) / "postprocessor.json").exists() + + # Load them back + loaded_pre = RobotProcessor.from_pretrained(tmp_dir, config_filename="preprocessor.json") + loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + + assert loaded_pre.name == "preprocessor" + assert loaded_post.name == "postprocessor" + assert len(loaded_pre) == 2 + assert len(loaded_post) == 1 + + +def test_auto_detect_single_config(): + """Test automatic config detection when there's only one JSON file.""" + step = MockStepWithTensorState() + pipeline = RobotProcessor([step], name="SingleConfig") + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load without specifying config_filename + loaded = RobotProcessor.from_pretrained(tmp_dir) + assert loaded.name == "SingleConfig" + + +def test_error_multiple_configs_no_filename(): + """Test error when multiple configs exist and no filename specified.""" + proc1 = RobotProcessor([MockStep()], name="processor1") + proc2 = RobotProcessor([MockStep()], name="processor2") + + with tempfile.TemporaryDirectory() as tmp_dir: + proc1.save_pretrained(tmp_dir) + proc2.save_pretrained(tmp_dir) + + # Should raise error + with 
pytest.raises(ValueError, match="Multiple .json files found"): + RobotProcessor.from_pretrained(tmp_dir) + + +def test_state_file_naming_with_indices(): + """Test that state files include pipeline name and step indices to avoid conflicts.""" + # Create multiple steps of same type with state + step1 = MockStepWithTensorState(name="norm1", window_size=5) + step2 = MockStepWithTensorState(name="norm2", window_size=10) + step3 = MockModuleStep(input_dim=5) + + pipeline = RobotProcessor([step1, step2, step3]) + + # Process some data to create state + for i in range(5): + transition = create_transition(observation=torch.randn(2, 5), reward=float(i)) + pipeline(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check state files have indices + state_files = sorted(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 3 + + # Files should be named with pipeline name prefix and indices + expected_names = [ + "robotprocessor_step_0.safetensors", + "robotprocessor_step_1.safetensors", + "robotprocessor_step_2.safetensors", + ] + actual_names = [f.name for f in state_files] + assert actual_names == expected_names + + +def test_state_file_naming_with_registry(): + """Test state file naming for registered steps includes pipeline name, index and registry name.""" + + # Register a test step + @ProcessorStepRegistry.register("test_stateful_step") + @dataclass + class TestStatefulStep: + value: int = 0 + + def __init__(self, value: int = 0): + self.value = value + self.state_tensor = torch.randn(3, 3) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self): + return {"value": self.value} + + def state_dict(self): + return {"state_tensor": self.state_tensor} + + def load_state_dict(self, state): + self.state_tensor = state["state_tensor"] + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test 
feature_contract here + return features + + try: + # Create pipeline with registered steps + step1 = TestStatefulStep(1) + step2 = TestStatefulStep(2) + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check state files + state_files = sorted(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 2 + + # Should include pipeline name, index and registry name + expected_names = [ + "robotprocessor_step_0_test_stateful_step.safetensors", + "robotprocessor_step_1_test_stateful_step.safetensors", + ] + actual_names = [f.name for f in state_files] + assert actual_names == expected_names + + finally: + # Cleanup registry + ProcessorStepRegistry.unregister("test_stateful_step") + + +# More comprehensive override tests +def test_override_with_nested_config(): + """Test overrides with nested configuration dictionaries.""" + + @ProcessorStepRegistry.register("complex_config_step") + @dataclass + class ComplexConfigStep: + name: str = "complex" + simple_param: int = 42 + nested_config: dict = None + + def __post_init__(self): + if self.nested_config is None: + self.nested_config = {"level1": {"level2": "default"}} + + def __call__(self, transition: EnvTransition) -> EnvTransition: + comp_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + comp_data = dict(comp_data) + comp_data["config_value"] = self.nested_config.get("level1", {}).get("level2", "missing") + + new_transition = transition.copy() + new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp_data + return new_transition + + def get_config(self): + return {"name": self.name, "simple_param": self.simple_param, "nested_config": self.nested_config} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + try: + step = ComplexConfigStep() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as 
tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load with nested override + loaded = RobotProcessor.from_pretrained( + tmp_dir, + overrides={"complex_config_step": {"nested_config": {"level1": {"level2": "overridden"}}}}, + ) + + # Test that override worked + transition = create_transition() + result = loaded(transition) + assert result[TransitionKey.COMPLEMENTARY_DATA]["config_value"] == "overridden" + finally: + ProcessorStepRegistry.unregister("complex_config_step") + + +def test_override_preserves_defaults(): + """Test that overrides only affect specified parameters.""" + step = MockStepWithNonSerializableParam(name="test", multiplier=2.0) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override only one parameter + loaded = RobotProcessor.from_pretrained( + tmp_dir, + overrides={ + "MockStepWithNonSerializableParam": { + "multiplier": 5.0 # Only override multiplier + } + }, + ) + + # Check that name was preserved from saved config + loaded_step = loaded.steps[0] + assert loaded_step.name == "test" # Original value + assert loaded_step.multiplier == 5.0 # Overridden value + + +def test_override_type_validation(): + """Test that type errors in overrides are caught properly.""" + step = MockStepWithTensorState(learning_rate=0.01) + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Try to override with wrong type + overrides = { + "MockStepWithTensorState": { + "window_size": "not_an_int" # Should be int + } + } + + with pytest.raises(ValueError, match="Failed to instantiate"): + RobotProcessor.from_pretrained(tmp_dir, overrides=overrides) + + +def test_override_with_callables(): + """Test overriding with callable objects.""" + + @ProcessorStepRegistry.register("callable_step") + @dataclass + class CallableStep: + name: str = "callable_step" + transform_fn: Any = None + + def __call__(self, transition: 
EnvTransition) -> EnvTransition: + obs = transition.get(TransitionKey.OBSERVATION) + if obs is not None and self.transform_fn is not None: + processed_obs = {} + for k, v in obs.items(): + processed_obs[k] = self.transform_fn(v) + + new_transition = transition.copy() + new_transition[TransitionKey.OBSERVATION] = processed_obs + return new_transition + return transition + + def get_config(self): + return {"name": self.name} + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + try: + step = CallableStep() + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Define a transform function + def double_values(x): + if isinstance(x, (int, float)): + return x * 2 + elif isinstance(x, torch.Tensor): + return x * 2 + return x + + # Load with callable override + loaded = RobotProcessor.from_pretrained( + tmp_dir, overrides={"callable_step": {"transform_fn": double_values}} + ) + + # Test it works + transition = create_transition(observation={"value": torch.tensor(5.0)}) + result = loaded(transition) + assert result[TransitionKey.OBSERVATION]["value"].item() == 10.0 + finally: + ProcessorStepRegistry.unregister("callable_step") + + +def test_override_multiple_same_class_warning(): + """Test behavior when multiple steps of same class exist.""" + step1 = MockStepWithNonSerializableParam(name="step1", multiplier=1.0) + step2 = MockStepWithNonSerializableParam(name="step2", multiplier=2.0) + pipeline = RobotProcessor([step1, step2]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override affects all instances of the class + loaded = RobotProcessor.from_pretrained( + tmp_dir, overrides={"MockStepWithNonSerializableParam": {"multiplier": 10.0}} + ) + + # Both steps get the same override + assert loaded.steps[0].multiplier == 10.0 + assert 
loaded.steps[1].multiplier == 10.0 + + # But original names are preserved + assert loaded.steps[0].name == "step1" + assert loaded.steps[1].name == "step2" + + +def test_config_filename_special_characters(): + """Test config filenames with special characters are sanitized.""" + # Processor name with special characters + pipeline = RobotProcessor([MockStep()], name="My/Processor\\With:Special*Chars") + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Check that filename was sanitized + json_files = list(Path(tmp_dir).glob("*.json")) + assert len(json_files) == 1 + + # Should have replaced special chars with underscores + expected_name = "my_processor_with_special_chars.json" + assert json_files[0].name == expected_name + + +def test_state_file_naming_with_multiple_processors(): + """Test that state files are properly prefixed with pipeline names to avoid conflicts.""" + # Create two processors with state + step1 = MockStepWithTensorState(name="norm", window_size=5) + preprocessor = RobotProcessor([step1], name="PreProcessor") + + step2 = MockStepWithTensorState(name="norm", window_size=10) + postprocessor = RobotProcessor([step2], name="PostProcessor") + + # Process some data to create state + for i in range(3): + transition = create_transition(reward=float(i)) + preprocessor(transition) + postprocessor(transition) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save both processors to the same directory + preprocessor.save_pretrained(tmp_dir) + postprocessor.save_pretrained(tmp_dir) + + # Check that all files exist and are distinct + assert (Path(tmp_dir) / "preprocessor.json").exists() + assert (Path(tmp_dir) / "postprocessor.json").exists() + assert (Path(tmp_dir) / "preprocessor_step_0.safetensors").exists() + assert (Path(tmp_dir) / "postprocessor_step_0.safetensors").exists() + + # Load both back and verify they work correctly + loaded_pre = RobotProcessor.from_pretrained(tmp_dir, 
config_filename="preprocessor.json") + loaded_post = RobotProcessor.from_pretrained(tmp_dir, config_filename="postprocessor.json") + + assert loaded_pre.name == "PreProcessor" + assert loaded_post.name == "PostProcessor" + assert loaded_pre.steps[0].window_size == 5 + assert loaded_post.steps[0].window_size == 10 + + +def test_override_with_device_strings(): + """Test overriding device parameters with string values.""" + + @ProcessorStepRegistry.register("device_aware_step") + @dataclass + class DeviceAwareStep: + device: str = "cpu" + + def __init__(self, device: str = "cpu"): + self.device = device + self.buffer = torch.zeros(10, device=device) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def get_config(self): + return {"device": str(self.device)} + + def state_dict(self): + return {"buffer": self.buffer} + + def load_state_dict(self, state): + self.buffer = state["buffer"] + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + # We do not test feature_contract here + return features + + try: + step = DeviceAwareStep(device="cpu") + pipeline = RobotProcessor([step]) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Override device + if torch.cuda.is_available(): + loaded = RobotProcessor.from_pretrained( + tmp_dir, overrides={"device_aware_step": {"device": "cuda:0"}} + ) + + loaded_step = loaded.steps[0] + assert loaded_step.device == "cuda:0" + # Note: buffer will still be on CPU from saved state + # until .to() is called on the processor + + finally: + ProcessorStepRegistry.unregister("device_aware_step") + + +def test_from_pretrained_nonexistent_path(): + """Test error handling when loading from non-existent sources.""" + from huggingface_hub.errors import HfHubHTTPError, HFValidationError + + # Test with an invalid repo ID (too many slashes) - caught by HF validation + with pytest.raises(HFValidationError): + 
RobotProcessor.from_pretrained("/path/that/does/not/exist") + + # Test with a non-existent but valid Hub repo format + with pytest.raises((FileNotFoundError, HfHubHTTPError)): + RobotProcessor.from_pretrained("nonexistent-user/nonexistent-repo") + + # Test with a local directory that exists but has no config files + with tempfile.TemporaryDirectory() as tmp_dir: + with pytest.raises(FileNotFoundError, match="No .json configuration files found"): + RobotProcessor.from_pretrained(tmp_dir) + + +def test_save_load_with_custom_converter_functions(): + """Test that custom to_transition and to_output functions are NOT saved.""" + + def custom_to_transition(batch): + # Custom conversion logic + return { + TransitionKey.OBSERVATION: batch.get("obs"), + TransitionKey.ACTION: batch.get("act"), + TransitionKey.REWARD: batch.get("rew", 0.0), + TransitionKey.DONE: batch.get("done", False), + TransitionKey.TRUNCATED: batch.get("truncated", False), + TransitionKey.INFO: {}, + TransitionKey.COMPLEMENTARY_DATA: {}, + } + + def custom_to_output(transition): + # Custom output format + return { + "obs": transition.get(TransitionKey.OBSERVATION), + "act": transition.get(TransitionKey.ACTION), + "rew": transition.get(TransitionKey.REWARD), + "done": transition.get(TransitionKey.DONE), + "truncated": transition.get(TransitionKey.TRUNCATED), + } + + # Create processor with custom converters + pipeline = RobotProcessor([MockStep()], to_transition=custom_to_transition, to_output=custom_to_output) + + with tempfile.TemporaryDirectory() as tmp_dir: + pipeline.save_pretrained(tmp_dir) + + # Load - should use default converters + loaded = RobotProcessor.from_pretrained(tmp_dir) + + # Verify it uses default converters by checking with standard batch format + batch = { + "observation.image": torch.randn(1, 3, 32, 32), + "action": torch.randn(1, 7), + "next.reward": torch.tensor([1.0]), + "next.done": torch.tensor([False]), + "next.truncated": torch.tensor([False]), + "info": {}, + } + + # Should 
work with standard format (wouldn't work with custom converter) + result = loaded(batch) + assert "observation.image" in result # Standard format preserved + + +class NonCompliantStep: + """Intentionally non-compliant: missing feature_contract.""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + +def test_construction_rejects_step_without_feature_contract(): + with pytest.raises(TypeError, match=r"must define feature_contract\(features\) -> dict\[str, Any\]"): + RobotProcessor([NonCompliantStep()]) + + +class NonCallableStep: + """Intentionally non-compliant: missing __call__.""" + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + return features + + +def test_construction_rejects_step_without_call(): + with pytest.raises(TypeError, match=r"must define __call__"): + RobotProcessor([NonCallableStep()]) + + +@dataclass +class FeatureContractAddStep: + """Adds a PolicyFeature""" + + key: str = "a" + value: PolicyFeature = PolicyFeature(type=FeatureType.STATE, shape=(1,)) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features[self.key] = self.value + return features + + +@dataclass +class FeatureContractMutateStep: + """Mutates a PolicyFeature""" + + key: str = "a" + fn: Callable[[PolicyFeature | None], PolicyFeature] = lambda x: x # noqa: E731 + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features[self.key] = self.fn(features.get(self.key)) + return features + + +@dataclass +class FeatureContractBadReturnStep: + """Returns a non-dict""" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, 
PolicyFeature]: + return ["not-a-dict"] + + +@dataclass +class FeatureContractRemoveStep: + """Removes a PolicyFeature""" + + key: str + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + features.pop(self.key, None) + return features + + +def test_feature_contract_orders_and_merges(policy_feature_factory): + p = RobotProcessor( + [ + FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), + FeatureContractMutateStep("a", lambda v: PolicyFeature(type=v.type, shape=(3,))), + FeatureContractAddStep("b", policy_feature_factory(FeatureType.ENV, (2,))), + ] + ) + out = p.feature_contract({}) + + assert out["a"].type == FeatureType.STATE and out["a"].shape == (3,) + assert out["b"].type == FeatureType.ENV and out["b"].shape == (2,) + assert_contract_is_typed(out) + + +def test_feature_contract_respects_initial_without_mutation(policy_feature_factory): + initial = { + "seed": policy_feature_factory(FeatureType.STATE, (7,)), + "nested": policy_feature_factory(FeatureType.ENV, (0,)), + } + p = RobotProcessor( + [ + FeatureContractMutateStep("seed", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 1,))), + FeatureContractMutateStep( + "nested", lambda v: PolicyFeature(type=v.type, shape=(v.shape[0] + 5,)) + ), + ] + ) + out = p.feature_contract(initial_features=initial) + + assert out["seed"].shape == (8,) + assert out["nested"].shape == (5,) + # Initial dict must be preserved + assert initial["seed"].shape == (7,) + assert initial["nested"].shape == (0,) + + assert_contract_is_typed(out) + + +def test_feature_contract_type_error_on_bad_step(): + p = RobotProcessor([FeatureContractAddStep(), FeatureContractBadReturnStep()]) + with pytest.raises(TypeError, match=r"\w+\.feature_contract must return dict\[str, Any\]"): + _ = p.feature_contract({}) + + +def test_feature_contract_execution_order_tracking(): + class 
Track: + def __init__(self, label): + self.label = label + + def __call__(self, transition: EnvTransition) -> EnvTransition: + return transition + + def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]: + code = {"A": 1, "B": 2, "C": 3}[self.label] + pf = features.get("order", PolicyFeature(type=FeatureType.ENV, shape=())) + features["order"] = PolicyFeature(type=pf.type, shape=pf.shape + (code,)) + return features + + out = RobotProcessor([Track("A"), Track("B"), Track("C")]).feature_contract({}) + assert out["order"].shape == (1, 2, 3) + + +def test_feature_contract_remove_key(policy_feature_factory): + p = RobotProcessor( + [ + FeatureContractAddStep("a", policy_feature_factory(FeatureType.STATE, (1,))), + FeatureContractRemoveStep("a"), + ] + ) + out = p.feature_contract({}) + assert "a" not in out + + +def test_feature_contract_remove_from_initial(policy_feature_factory): + initial = { + "keep": policy_feature_factory(FeatureType.STATE, (1,)), + "drop": policy_feature_factory(FeatureType.STATE, (1,)), + } + p = RobotProcessor([FeatureContractRemoveStep("drop")]) + out = p.feature_contract(initial_features=initial) + assert "drop" not in out and out["keep"] == initial["keep"] diff --git a/tests/processor/test_rename_processor.py b/tests/processor/test_rename_processor.py new file mode 100644 index 000000000..229d57f9f --- /dev/null +++ b/tests/processor/test_rename_processor.py @@ -0,0 +1,467 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from lerobot.configs.types import FeatureType +from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey +from tests.conftest import assert_contract_is_typed + + +def create_transition( + observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None +): + """Helper to create an EnvTransition dictionary.""" + return { + TransitionKey.OBSERVATION: observation, + TransitionKey.ACTION: action, + TransitionKey.REWARD: reward, + TransitionKey.DONE: done, + TransitionKey.TRUNCATED: truncated, + TransitionKey.INFO: info, + TransitionKey.COMPLEMENTARY_DATA: complementary_data, + } + + +def test_basic_renaming(): + """Test basic key renaming functionality.""" + rename_map = { + "old_key1": "new_key1", + "old_key2": "new_key2", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "old_key1": torch.tensor([1.0, 2.0]), + "old_key2": np.array([3.0, 4.0]), + "unchanged_key": "keep_me", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renamed keys + assert "new_key1" in processed_obs + assert "new_key2" in processed_obs + assert "old_key1" not in processed_obs + assert "old_key2" not in processed_obs + + # Check values are preserved + torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0])) + np.testing.assert_array_equal(processed_obs["new_key2"], 
np.array([3.0, 4.0])) + + # Check unchanged key is preserved + assert processed_obs["unchanged_key"] == "keep_me" + + +def test_empty_rename_map(): + """Test processor with empty rename map (should pass through unchanged).""" + processor = RenameProcessor(rename_map={}) + + observation = { + "key1": torch.tensor([1.0]), + "key2": "value2", + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # All keys should be unchanged + assert processed_obs.keys() == observation.keys() + torch.testing.assert_close(processed_obs["key1"], observation["key1"]) + assert processed_obs["key2"] == observation["key2"] + + +def test_none_observation(): + """Test processor with None observation.""" + processor = RenameProcessor(rename_map={"old": "new"}) + + transition = create_transition() + result = processor(transition) + + # Should return transition unchanged + assert result == transition + + +def test_overlapping_rename(): + """Test renaming when new names might conflict.""" + rename_map = { + "a": "b", + "b": "c", # This creates a potential conflict + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "a": 1, + "b": 2, + "x": 3, + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that renaming happens correctly + assert "a" not in processed_obs + assert processed_obs["b"] == 1 # 'a' renamed to 'b' + assert processed_obs["c"] == 2 # original 'b' renamed to 'c' + assert processed_obs["x"] == 3 + + +def test_partial_rename(): + """Test renaming only some keys.""" + rename_map = { + "observation.state": "observation.proprio_state", + "pixels": "observation.image", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "observation.state": torch.randn(10), + "pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8), + 
"reward": 1.0, + "info": {"episode": 1}, + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renamed keys + assert "observation.proprio_state" in processed_obs + assert "observation.image" in processed_obs + assert "observation.state" not in processed_obs + assert "pixels" not in processed_obs + + # Check unchanged keys + assert processed_obs["reward"] == 1.0 + assert processed_obs["info"] == {"episode": 1} + + +def test_get_config(): + """Test configuration serialization.""" + rename_map = { + "old1": "new1", + "old2": "new2", + } + processor = RenameProcessor(rename_map=rename_map) + + config = processor.get_config() + assert config == {"rename_map": rename_map} + + +def test_state_dict(): + """Test state dict (should be empty for RenameProcessor).""" + processor = RenameProcessor(rename_map={"old": "new"}) + + state = processor.state_dict() + assert state == {} + + # Load state dict should work even with empty dict + processor.load_state_dict({}) + + +def test_integration_with_robot_processor(): + """Test integration with RobotProcessor pipeline.""" + rename_map = { + "agent_pos": "observation.state", + "pixels": "observation.image", + } + rename_processor = RenameProcessor(rename_map=rename_map) + + pipeline = RobotProcessor([rename_processor]) + + observation = { + "agent_pos": np.array([1.0, 2.0, 3.0]), + "pixels": np.zeros((32, 32, 3), dtype=np.uint8), + "other_data": "preserve_me", + } + transition = create_transition( + observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={} + ) + + result = pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renaming worked through pipeline + assert "observation.state" in processed_obs + assert "observation.image" in processed_obs + assert "agent_pos" not in processed_obs + assert "pixels" not in processed_obs + assert 
processed_obs["other_data"] == "preserve_me" + + # Check other transition elements unchanged + assert result[TransitionKey.REWARD] == 0.5 + assert result[TransitionKey.DONE] is False + + +def test_save_and_load_pretrained(): + """Test saving and loading processor with RobotProcessor.""" + rename_map = { + "old_state": "observation.state", + "old_image": "observation.image", + } + processor = RenameProcessor(rename_map=rename_map) + pipeline = RobotProcessor([processor], name="TestRenameProcessor") + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save pipeline + pipeline.save_pretrained(tmp_dir) + + # Check files were created + config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor" + assert config_path.exists() + + # No state files should be created for RenameProcessor + state_files = list(Path(tmp_dir).glob("*.safetensors")) + assert len(state_files) == 0 + + # Load pipeline + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + + assert loaded_pipeline.name == "TestRenameProcessor" + assert len(loaded_pipeline) == 1 + + # Check that loaded processor works correctly + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RenameProcessor) + assert loaded_processor.rename_map == rename_map + + # Test functionality after loading + observation = {"old_state": [1, 2, 3], "old_image": "image_data"} + transition = create_transition(observation=observation) + + result = loaded_pipeline(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + assert "observation.state" in processed_obs + assert "observation.image" in processed_obs + assert processed_obs["observation.state"] == [1, 2, 3] + assert processed_obs["observation.image"] == "image_data" + + +def test_registry_functionality(): + """Test that RenameProcessor is properly registered.""" + # Check that it's registered + assert "rename_processor" in ProcessorStepRegistry.list() + + # Get from registry + retrieved_class = 
ProcessorStepRegistry.get("rename_processor") + assert retrieved_class is RenameProcessor + + # Create instance from registry + instance = retrieved_class(rename_map={"old": "new"}) + assert isinstance(instance, RenameProcessor) + assert instance.rename_map == {"old": "new"} + + +def test_registry_based_save_load(): + """Test save/load using registry name instead of module path.""" + processor = RenameProcessor(rename_map={"key1": "renamed_key1"}) + pipeline = RobotProcessor([processor]) + + with tempfile.TemporaryDirectory() as tmp_dir: + # Save and load + pipeline.save_pretrained(tmp_dir) + + # Verify config uses registry name + import json + + with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor" + config = json.load(f) + + assert "registry_name" in config["steps"][0] + assert config["steps"][0]["registry_name"] == "rename_processor" + assert "class" not in config["steps"][0] # Should use registry, not module path + + # Load should work + loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir) + loaded_processor = loaded_pipeline.steps[0] + assert isinstance(loaded_processor, RenameProcessor) + assert loaded_processor.rename_map == {"key1": "renamed_key1"} + + +def test_chained_rename_processors(): + """Test multiple RenameProcessors in a pipeline.""" + # First processor: rename raw keys to intermediate format + processor1 = RenameProcessor( + rename_map={ + "pos": "agent_position", + "img": "camera_image", + } + ) + + # Second processor: rename to final format + processor2 = RenameProcessor( + rename_map={ + "agent_position": "observation.state", + "camera_image": "observation.image", + } + ) + + pipeline = RobotProcessor([processor1, processor2]) + + observation = { + "pos": np.array([1.0, 2.0]), + "img": "image_data", + "extra": "keep_me", + } + transition = create_transition(observation=observation) + + # Step through to see intermediate results + results = list(pipeline.step_through(transition)) + + # After first 
processor + assert "agent_position" in results[1][TransitionKey.OBSERVATION] + assert "camera_image" in results[1][TransitionKey.OBSERVATION] + + # After second processor + final_obs = results[2][TransitionKey.OBSERVATION] + assert "observation.state" in final_obs + assert "observation.image" in final_obs + assert final_obs["extra"] == "keep_me" + + # Original keys should be gone + assert "pos" not in final_obs + assert "img" not in final_obs + assert "agent_position" not in final_obs + assert "camera_image" not in final_obs + + +def test_nested_observation_rename(): + """Test renaming with nested observation structures.""" + rename_map = { + "observation.images.left": "observation.camera.left_view", + "observation.images.right": "observation.camera.right_view", + "observation.proprio": "observation.proprioception", + } + processor = RenameProcessor(rename_map=rename_map) + + observation = { + "observation.images.left": torch.randn(3, 64, 64), + "observation.images.right": torch.randn(3, 64, 64), + "observation.proprio": torch.randn(7), + "observation.gripper": torch.tensor([0.0]), # Not renamed + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check renames + assert "observation.camera.left_view" in processed_obs + assert "observation.camera.right_view" in processed_obs + assert "observation.proprioception" in processed_obs + + # Check unchanged key + assert "observation.gripper" in processed_obs + + # Check old keys removed + assert "observation.images.left" not in processed_obs + assert "observation.images.right" not in processed_obs + assert "observation.proprio" not in processed_obs + + +def test_value_types_preserved(): + """Test that various value types are preserved during renaming.""" + rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"} + processor = RenameProcessor(rename_map=rename_map) + + tensor_value = 
torch.randn(3, 3) + array_value = np.random.rand(2, 2) + + observation = { + "old_tensor": tensor_value, + "old_array": array_value, + "old_scalar": 42, + "old_string": "hello", + "old_dict": {"nested": "value"}, + "old_list": [1, 2, 3], + } + transition = create_transition(observation=observation) + + result = processor(transition) + processed_obs = result[TransitionKey.OBSERVATION] + + # Check that values and types are preserved + assert torch.equal(processed_obs["new_tensor"], tensor_value) + assert np.array_equal(processed_obs["new_array"], array_value) + assert processed_obs["new_scalar"] == 42 + assert processed_obs["old_string"] == "hello" + assert processed_obs["old_dict"] == {"nested": "value"} + assert processed_obs["old_list"] == [1, 2, 3] + + +def test_feature_contract_basic_renaming(policy_feature_factory): + processor = RenameProcessor(rename_map={"a": "x", "b": "y"}) + features = { + "a": policy_feature_factory(FeatureType.STATE, (2,)), + "b": policy_feature_factory(FeatureType.ACTION, (3,)), + "c": policy_feature_factory(FeatureType.ENV, (1,)), + } + + out = processor.feature_contract(features.copy()) + + # Values preserved and typed + assert out["x"] == features["a"] + assert out["y"] == features["b"] + assert out["c"] == features["c"] + + assert_contract_is_typed(out) + # Input not mutated + assert set(features) == {"a", "b", "c"} + + +def test_feature_contract_overlapping_keys(policy_feature_factory): + # Overlapping renames: both 'a' and 'b' exist. 
'a'->'b', 'b'->'c' + processor = RenameProcessor(rename_map={"a": "b", "b": "c"}) + features = { + "a": policy_feature_factory(FeatureType.STATE, (1,)), + "b": policy_feature_factory(FeatureType.STATE, (2,)), + } + out = processor.feature_contract(features) + + assert set(out) == {"b", "c"} + assert out["b"] == features["a"] # 'a' renamed to'b' + assert out["c"] == features["b"] # 'b' renamed to 'c' + assert_contract_is_typed(out) + + +def test_feature_contract_chained_processors(policy_feature_factory): + # Chain two rename processors at the contract level + processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"}) + processor2 = RenameProcessor( + rename_map={"agent_position": "observation.state", "camera_image": "observation.image"} + ) + pipeline = RobotProcessor([processor1, processor2]) + + spec = { + "pos": policy_feature_factory(FeatureType.STATE, (7,)), + "img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)), + "extra": policy_feature_factory(FeatureType.ENV, (1,)), + } + out = pipeline.feature_contract(initial_features=spec) + + assert set(out) == {"observation.state", "observation.image", "extra"} + assert out["observation.state"] == spec["pos"] + assert out["observation.image"] == spec["img"] + assert out["extra"] == spec["extra"] + assert_contract_is_typed(out)