mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b883328e6c | |||
| 49ecbeb33f | |||
| 88f7bf01c1 | |||
| 6daa579ce1 | |||
| 06bebd97b3 | |||
| e0096feb6a | |||
| 90d3a99aa1 | |||
| 8c577525c1 | |||
| f771e3eaf1 | |||
| 240a3892ae | |||
| 3e24ecaf54 | |||
| 60dc8e3a5d | |||
| dcb305ffb2 | |||
| 11525cedeb | |||
| 2f8d98b05e | |||
| 1baaa77a86 | |||
| 91ed6097bc | |||
| 945e1ff266 | |||
| 71eff183ff | |||
| 67196c9d53 | |||
| 5695432142 | |||
| c14ab9e97b | |||
| c7c3b477d6 | |||
| b267cd40f7 | |||
| 7fe6adaf61 | |||
| 4b88842d20 |
@@ -19,6 +19,11 @@ on:
|
||||
tags:
|
||||
- 'v*.*.*' # Trigger on tags like v0.1.0, v1.0.0
|
||||
|
||||
# Sets up the environment variables
|
||||
env:
|
||||
UV_VERSION: "0.8.0"
|
||||
PYTHON_VERSION: "3.10"
|
||||
|
||||
jobs:
|
||||
# This job builds the Python package and publishes it to PyPI
|
||||
build-and-publish:
|
||||
@@ -50,6 +55,7 @@ jobs:
|
||||
VERSION_NUMBER=${VERSION#v}
|
||||
echo "tag_version=$VERSION_NUMBER" >> $GITHUB_OUTPUT
|
||||
- name: Check if version matches pyproject.toml
|
||||
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
|
||||
# zizmor: ignore[template-injection]
|
||||
run: |
|
||||
TAG_VERSION=${{ steps.extract_info.outputs.tag_version }}
|
||||
@@ -86,13 +92,29 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
# zizmor: ignore[template-injection]
|
||||
run: gh release create ${{ github.ref_name }} --release-name "Release ${{ github.ref_name }}" --generate-notes ./dist/*
|
||||
run: |
|
||||
gh release create ${{ github.ref_name }} \
|
||||
--title "Release ${{ github.ref_name }}" \
|
||||
--generate-notes \
|
||||
--draft=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \
|
||||
--prerelease=$([[ "${{ github.ref_name }}" == *-* ]] && echo true || echo false) \
|
||||
./dist/*
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: startsWith(github.ref, 'refs/tags/v')
|
||||
- name: Publish to TestPyPI for pre-releases
|
||||
# True for tags like 'v0.2.0-rc1'
|
||||
if: startsWith(github.ref, 'refs/tags/v') && contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
password: ${{ secrets.PYPI_API_TOKEN }}
|
||||
repository-url: https://test.pypi.org/legacy/
|
||||
verbose: true
|
||||
print-hash: true
|
||||
|
||||
- name: Publish to PyPI
|
||||
if: startsWith(github.ref, 'refs/tags/v') && !contains(github.ref, '-')
|
||||
uses: pypa/gh-action-pypi-publish@v1.12.4 # zizmor: ignore[unpinned-uses, use-trusted-publishing]
|
||||
with:
|
||||
verbose: true
|
||||
print-hash: true
|
||||
|
||||
# This job runs end-to-end tests on the release
|
||||
test-release:
|
||||
@@ -119,15 +141,31 @@ jobs:
|
||||
enable-cache: true
|
||||
version: ${{ env.UV_VERSION }}
|
||||
python-version: ${{ env.PYTHON_VERSION }}
|
||||
- name: Create uv virtual environment
|
||||
run: uv venv
|
||||
- name: Install lerobot release
|
||||
run: uv run pip install lerobot==${{ needs.build-and-publish.outputs.version }} # zizmor: ignore[template-injection]
|
||||
|
||||
# zizmor: ignore[template-injection]
|
||||
run: |
|
||||
VERSION="${{ needs.build-and-publish.outputs.version }}"
|
||||
if [[ "$VERSION" == *-* ]]; then
|
||||
BASE_VERSION="${VERSION%%-*}"
|
||||
echo "Installing pre-release version $BASE_VERSION from TestPyPI..."
|
||||
uv pip install \
|
||||
--index-url https://test.pypi.org/simple/ \
|
||||
--extra-index-url https://pypi.org/simple \
|
||||
--index-strategy unsafe-best-match \
|
||||
"lerobot[all]==$BASE_VERSION"
|
||||
else
|
||||
echo "Installing release version $VERSION from PyPI..."
|
||||
uv pip install "lerobot[all]==$VERSION"
|
||||
fi
|
||||
- name: Check lerobot version
|
||||
run: uv run lerobot --version
|
||||
run: uv run python -c "import lerobot; print(lerobot.__version__)"
|
||||
|
||||
- name: Run end-to-end tests
|
||||
run: uv run make test-end-to-end
|
||||
|
||||
|
||||
# TODO(Steven): Publish draft/pre-release and to test pypi
|
||||
# TODO(Steven): Publish draft/pre-release and to test pypi weekly
|
||||
# TODO(Steven): Separate build and publish job
|
||||
# TODO(Steven): Tag documentation with the same version as the package
|
||||
|
||||
@@ -1,25 +1,21 @@
|
||||
<p align="center">
|
||||
<picture>
|
||||
<source media="(prefers-color-scheme: dark)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<source media="(prefers-color-scheme: light)" srcset="media/lerobot-logo-thumbnail.png">
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="media/lerobot-logo-thumbnail.png" style="max-width: 100%;">
|
||||
</picture>
|
||||
<img alt="LeRobot, Hugging Face Robotics Library" src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lerobot-logo-thumbnail.png" width="100%">
|
||||
<br/>
|
||||
<br/>
|
||||
</p>
|
||||
|
||||
<div align="center">
|
||||
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nightly-tests.yml?query=branch%3Amain)
|
||||
[](https://codecov.io/gh/huggingface/lerobot)
|
||||
[](https://github.com/huggingface/lerobot/actions/workflows/nighty.yml?query=branch%3Amain)
|
||||
[](https://www.python.org/downloads/)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/LICENSE)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://pypi.org/project/lerobot/)
|
||||
[](https://github.com/huggingface/lerobot/tree/main/examples)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
||||
[](https://github.com/huggingface/lerobot/blob/main/CODE_OF_CONDUCT.md)
|
||||
[](https://discord.gg/s3KuuzsPFb)
|
||||
|
||||
<!-- [](https://codecov.io/gh/huggingface/lerobot) -->
|
||||
|
||||
</div>
|
||||
|
||||
<h2 align="center">
|
||||
@@ -29,10 +25,10 @@
|
||||
|
||||
<div align="center">
|
||||
<img
|
||||
src="media/hope_jr/hopejr.png?raw=true"
|
||||
src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/hope_jr/hopejr.png"
|
||||
alt="HopeJR robot"
|
||||
title="HopeJR robot"
|
||||
style="width: 60%;"
|
||||
width="60%"
|
||||
/>
|
||||
|
||||
<p><strong>Meet HopeJR – A humanoid robot arm and hand for dexterous manipulation!</strong></p>
|
||||
@@ -51,20 +47,12 @@
|
||||
</h2>
|
||||
|
||||
<div align="center">
|
||||
<div style="display: flex; gap: 1rem; justify-content: center; align-items: center;" >
|
||||
<img
|
||||
src="media/so101/so101.webp?raw=true"
|
||||
alt="SO-101 follower arm"
|
||||
title="SO-101 follower arm"
|
||||
style="width: 40%;"
|
||||
/>
|
||||
<img
|
||||
src="media/so101/so101-leader.webp?raw=true"
|
||||
alt="SO-101 leader arm"
|
||||
title="SO-101 leader arm"
|
||||
style="width: 40%;"
|
||||
/>
|
||||
</div>
|
||||
<table>
|
||||
<tr>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101.webp" alt="SO-101 follower arm" title="SO-101 follower arm" width="90%"/></td>
|
||||
<td align="center"><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/so101/so101-leader.webp" alt="SO-101 leader arm" title="SO-101 leader arm" width="90%"/></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
<p><strong>Meet the updated SO100, the SO-101 – Just €114 per arm!</strong></p>
|
||||
<p>Train it in minutes with a few simple moves on your laptop.</p>
|
||||
@@ -76,7 +64,7 @@
|
||||
<p>Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!</p>
|
||||
<p>Check out the <a href="https://huggingface.co/docs/lerobot/lekiwi">LeKiwi tutorial</a> and bring your robot to life on wheels.</p>
|
||||
|
||||
<img src="media/lekiwi/kiwi.webp?raw=true" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/lekiwi/kiwi.webp" alt="LeKiwi mobile robot" title="LeKiwi mobile robot" width="50%">
|
||||
</div>
|
||||
|
||||
<br/>
|
||||
@@ -99,9 +87,9 @@
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td><img src="media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/aloha_act.gif" width="100%" alt="ACT policy on ALOHA env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/simxarm_tdmpc.gif" width="100%" alt="TDMPC policy on SimXArm env"/></td>
|
||||
<td><img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/gym/pusht_diffusion.gif" width="100%" alt="Diffusion policy on PushT env"/></td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center">ACT policy on ALOHA env</td>
|
||||
@@ -110,23 +98,11 @@
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
### Acknowledgment
|
||||
|
||||
- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
|
||||
|
||||
## Installation
|
||||
|
||||
Download our source code:
|
||||
LeRobot works with Python 3.10+ and PyTorch 2.2+.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
### Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10 and activate it, e.g. with [`miniconda`](https://docs.anaconda.com/free/miniconda/index.html):
|
||||
|
||||
@@ -151,7 +127,18 @@ conda install ffmpeg -c conda-forge
|
||||
>
|
||||
> - _[On Linux only]_ Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
### Install LeRobot 🤗
|
||||
|
||||
#### From Source
|
||||
|
||||
First, clone the repository and navigate into the directory:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
@@ -172,6 +159,34 @@ For instance, to install 🤗 LeRobot with aloha and pusht, use:
|
||||
pip install -e ".[aloha, pusht]"
|
||||
```
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
|
||||
```bash
|
||||
pip install lerobot
|
||||
```
|
||||
|
||||
_This installs only the default dependencies._
|
||||
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following:
|
||||
|
||||
```bash
|
||||
pip install 'lerobot[all]' # All available features
|
||||
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
|
||||
pip install 'lerobot[feetech]' # Feetech motor support
|
||||
```
|
||||
|
||||
_Replace `[...]` with your desired features._
|
||||
|
||||
**Available Tags:**
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
### Weights & Biases
|
||||
|
||||
To use [Weights and Biases](https://docs.wandb.ai/quickstart) for experiment tracking, log in with
|
||||
|
||||
```bash
|
||||
@@ -182,7 +197,7 @@ wandb login
|
||||
|
||||
### Visualize datasets
|
||||
|
||||
Check out [example 1](./examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
Check out [example 1](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) that illustrates how to use our dataset class which automatically downloads data from the Hugging Face hub.
|
||||
|
||||
You can also locally visualize episodes from a dataset on the hub by executing our script from the command line:
|
||||
|
||||
@@ -212,7 +227,7 @@ Our script can also visualize datasets stored on a distant server. See `python -
|
||||
|
||||
A dataset in `LeRobotDataset` format is very simple to use. It can be loaded from a repository on the Hugging Face hub or a local folder simply with e.g. `dataset = LeRobotDataset("lerobot/aloha_static_coffee")` and can be indexed into like any Hugging Face and PyTorch dataset. For instance `dataset[0]` will retrieve a single temporal frame from the dataset containing observation(s) and an action as PyTorch tensors ready to be fed to a model.
|
||||
|
||||
A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
|
||||
A specificity of `LeRobotDataset` is that, rather than retrieving a single frame by its index, we can retrieve several frames based on their temporal relationship with the indexed frame, by setting `delta_timestamps` to a list of relative times with respect to the indexed frame. For example, with `delta_timestamps = {"observation.image": [-1, -0.5, -0.2, 0]}` one can retrieve, for a given index, 4 frames: 3 "previous" frames 1 second, 0.5 seconds, and 0.2 seconds before the indexed frame, and the indexed frame itself (corresponding to the 0 entry). See example [1_load_lerobot_dataset.py](https://github.com/huggingface/lerobot/blob/main/examples/1_load_lerobot_dataset.py) for more details on `delta_timestamps`.
|
||||
|
||||
Under the hood, the `LeRobotDataset` format makes use of several ways to serialize data which can be useful to understand if you plan to work more closely with this format. We tried to make a flexible yet simple dataset format that would cover most type of features and specificities present in reinforcement learning and robotics, in simulation and in real-world, with a focus on cameras and robot states but easily extended to other types of sensory inputs as long as they can be represented by a tensor.
|
||||
|
||||
@@ -256,7 +271,7 @@ Dataset can be uploaded/downloaded from the HuggingFace hub seamlessly. To work
|
||||
|
||||
### Evaluate a pretrained policy
|
||||
|
||||
Check out [example 2](./examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
|
||||
Check out [example 2](https://github.com/huggingface/lerobot/blob/main/examples/2_evaluate_pretrained_policy.py) that illustrates how to download a pretrained policy from Hugging Face hub, and run an evaluation on its corresponding environment.
|
||||
|
||||
We also provide a more capable script to parallelize the evaluation over multiple environments during the same rollout. Here is an example with a pretrained model hosted on [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht):
|
||||
|
||||
@@ -280,13 +295,13 @@ See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
|
||||
### Train your own policy
|
||||
|
||||
Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
Check out [example 3](https://github.com/huggingface/lerobot/blob/main/examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md) that shows how to use our training script from command line.
|
||||
|
||||
To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`.
|
||||
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](./examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
A link to the wandb logs for the run will also show up in yellow in your terminal. Here is an example of what they look like in your browser. Please also check [here](https://github.com/huggingface/lerobot/blob/main/examples/4_train_policy_with_script.md#typical-logs-and-metrics) for the explanation of some commonly used metrics in logs.
|
||||
|
||||

|
||||
\<img src="https://raw.githubusercontent.com/huggingface/lerobot/main/media/wandb.png" alt="WandB logs example"\>
|
||||
|
||||
Note: For efficiency, during training every checkpoint is evaluated on a low number of episodes. You may use `--eval.n_episodes=500` to evaluate on more episodes than the default. Or, after training, you may want to re-evaluate your best checkpoints on more episodes or change the evaluation settings. See `python -m lerobot.scripts.eval --help` for more instructions.
|
||||
|
||||
@@ -305,26 +320,6 @@ reproduces SOTA results for Diffusion Policy on the PushT task.
|
||||
|
||||
If you would like to contribute to 🤗 LeRobot, please check out our [contribution guide](https://github.com/huggingface/lerobot/blob/main/CONTRIBUTING.md).
|
||||
|
||||
<!-- ### Add a new dataset
|
||||
|
||||
To add a dataset to the hub, you need to login using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Then point to your raw dataset folder (e.g. `data/aloha_static_pingpong_test_raw`), and push your dataset to the hub with:
|
||||
```bash
|
||||
python lerobot/scripts/push_dataset_to_hub.py \
|
||||
--raw-dir data/aloha_static_pingpong_test_raw \
|
||||
--out-dir data \
|
||||
--repo-id lerobot/aloha_static_pingpong_test \
|
||||
--raw-format aloha_hdf5
|
||||
```
|
||||
|
||||
See `python lerobot/scripts/push_dataset_to_hub.py --help` for more instructions.
|
||||
|
||||
If your dataset format is not supported, implement your own in `lerobot/datasets/push_dataset_to_hub/${raw_format}_format.py` by copying examples like [pusht_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/pusht_zarr_format.py), [umi_zarr](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/umi_zarr_format.py), [aloha_hdf5](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/aloha_hdf5_format.py), or [xarm_pkl](https://github.com/huggingface/lerobot/blob/main/lerobot/datasets/push_dataset_to_hub/xarm_pkl_format.py). -->
|
||||
|
||||
### Add a pretrained policy
|
||||
|
||||
Once you have trained a policy you may upload it to the Hugging Face hub using a hub id that looks like `${hf_user}/${repo_name}` (e.g. [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht)).
|
||||
@@ -341,34 +336,16 @@ To upload these to the hub, run the following:
|
||||
huggingface-cli upload ${hf_user}/${repo_name} path/to/pretrained_model
|
||||
```
|
||||
|
||||
See [eval.py](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py) for an example of how other people may use your policy.
|
||||
See [eval.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/eval.py) for an example of how other people may use your policy.
|
||||
|
||||
### Improve your code with profiling
|
||||
### Acknowledgment
|
||||
|
||||
An example of a code snippet to profile the evaluation of a policy:
|
||||
|
||||
<!-- prettier-ignore-start -->
|
||||
```python
|
||||
from torch.profiler import profile, record_function, ProfilerActivity
|
||||
|
||||
def trace_handler(prof):
|
||||
prof.export_chrome_trace(f"tmp/trace_schedule_{prof.step_num}.json")
|
||||
|
||||
with profile(
|
||||
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
|
||||
schedule=torch.profiler.schedule(
|
||||
wait=2,
|
||||
warmup=2,
|
||||
active=3,
|
||||
),
|
||||
on_trace_ready=trace_handler
|
||||
) as prof:
|
||||
with record_function("eval_policy"):
|
||||
for i in range(num_episodes):
|
||||
prof.step()
|
||||
# insert code to profile, potentially whole body of eval_policy function
|
||||
```
|
||||
<!-- prettier-ignore-end -->
|
||||
- The LeRobot team 🤗 for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla).
|
||||
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
|
||||
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
|
||||
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
|
||||
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
|
||||
- Thanks to [Seungjae (Jay) Lee](https://sjlee.cc/), [Mahi Shafiullah](https://mahis.life/) and colleagues for open sourcing [VQ-BeT](https://sjlee.cc/vq-bet/) policy and helping us adapt the codebase to our repository. The policy is adapted from [VQ-BeT repo](https://github.com/jayLEE0301/vq_bet_official).
|
||||
|
||||
## Citation
|
||||
|
||||
@@ -376,83 +353,13 @@ If you want, you can cite this work with:
|
||||
|
||||
```bibtex
|
||||
@misc{cadene2024lerobot,
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascal, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas},
|
||||
title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch},
|
||||
howpublished = "\url{https://github.com/huggingface/lerobot}",
|
||||
year = {2024}
|
||||
}
|
||||
```
|
||||
|
||||
Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below:
|
||||
|
||||
- [SmolVLA](https://arxiv.org/abs/2506.01844)
|
||||
|
||||
```bibtex
|
||||
@article{shukor2025smolvla,
|
||||
title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics},
|
||||
author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi},
|
||||
journal={arXiv preprint arXiv:2506.01844},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
|
||||
- [Diffusion Policy](https://diffusion-policy.cs.columbia.edu)
|
||||
|
||||
```bibtex
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
year = {2024},
|
||||
}
|
||||
```
|
||||
|
||||
- [ACT or ALOHA](https://tonyzhaozh.github.io/aloha)
|
||||
|
||||
```bibtex
|
||||
@article{zhao2023learning,
|
||||
title={Learning fine-grained bimanual manipulation with low-cost hardware},
|
||||
author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea},
|
||||
journal={arXiv preprint arXiv:2304.13705},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
- [TDMPC](https://www.nicklashansen.com/td-mpc/)
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={ICML},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
|
||||
- [VQ-BeT](https://sjlee.cc/vq-bet/)
|
||||
|
||||
```bibtex
|
||||
@article{lee2024behavior,
|
||||
title={Behavior generation with latent actions},
|
||||
author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal={arXiv preprint arXiv:2403.03181},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
- [HIL-SERL](https://hil-serl.github.io/)
|
||||
|
||||
```bibtex
|
||||
@Article{luo2024hilserl,
|
||||
title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning},
|
||||
author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine},
|
||||
year={2024},
|
||||
eprint={2410.21845},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO}
|
||||
}
|
||||
```
|
||||
|
||||
## Star History
|
||||
|
||||
[](https://star-history.com/#huggingface/lerobot&Timeline)
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
# docs-requirements.txt
|
||||
hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main
|
||||
watchdog>=6.0.0
|
||||
+1
-1
@@ -20,7 +20,7 @@ To generate the documentation, you first have to build it. Several packages are
|
||||
you can install them with the following command, at the root of the code repository:
|
||||
|
||||
```bash
|
||||
pip install -e ".[docs]"
|
||||
pip install -e . -r docs-requirements.txt
|
||||
```
|
||||
|
||||
You will also need `nodejs`. Please refer to their [installation page](https://nodejs.org/en/download)
|
||||
|
||||
@@ -294,7 +294,7 @@ dataset.push_to_hub()
|
||||
|
||||
#### Dataset upload
|
||||
|
||||
Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. https://huggingface.co/datasets/cadene/so101_test) that you can obtain by running:
|
||||
Locally, your dataset is stored in this folder: `~/.cache/huggingface/lerobot/{repo-id}`. At the end of data recording, your dataset will be uploaded on your Hugging Face page (e.g. `https://huggingface.co/datasets/${HF_USER}/so101_test`) that you can obtain by running:
|
||||
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/so101_test
|
||||
@@ -428,7 +428,7 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
@@ -444,7 +444,7 @@ python -m lerobot.scripts.train \
|
||||
Let's explain the command:
|
||||
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/so101_test`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
@@ -462,9 +462,9 @@ If you do not want to push your model to the hub after training use `--policy.pu
|
||||
|
||||
Additionally you can provide extra `tags` or specify a `license` for your model or make the model repo `private` by adding this: `--policy.private=true --policy.tags=\[ppo,rl\] --policy.license=mit`
|
||||
|
||||
#### Train using Collab
|
||||
#### Train using Google Colab
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ If you uploaded your dataset to the hub you can [visualize your dataset online](
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
python -m lerobot.scripts.train \
|
||||
@@ -111,7 +111,7 @@ python -m lerobot.scripts.train \
|
||||
Let's explain the command:
|
||||
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
|
||||
@@ -1,15 +1,6 @@
|
||||
# Installation
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
Currently only available from source.
|
||||
|
||||
Download our source code:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
## Environment Setup
|
||||
|
||||
Create a virtual environment with Python 3.10, using [`Miniconda`](https://docs.anaconda.com/miniconda/install/#quick-command-line-install)
|
||||
|
||||
@@ -40,12 +31,49 @@ conda install ffmpeg -c conda-forge
|
||||
>
|
||||
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
Install 🤗 LeRobot:
|
||||
## Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
|
||||
First, clone the repository and navigate into the directory:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
```
|
||||
|
||||
Then, install the library in editable mode. This is useful if you plan to contribute to the code.
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
**Core Library:**
|
||||
Install the base package with:
|
||||
|
||||
```bash
|
||||
pip install lerobot
|
||||
```
|
||||
|
||||
_This installs only the default dependencies._
|
||||
|
||||
**Extra Features:**
|
||||
To install additional functionality, use one of the following:
|
||||
|
||||
```bash
|
||||
pip install 'lerobot[all]' # All available features
|
||||
pip install 'lerobot[aloha,pusht]' # Specific features (Aloha & Pusht)
|
||||
pip install 'lerobot[feetech]' # Feetech motor support
|
||||
```
|
||||
|
||||
_Replace `[...]` with your desired features._
|
||||
|
||||
**Available Tags:**
|
||||
For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
If you encounter build errors, you may need to install additional dependencies: `cmake`, `build-essential`, and `ffmpeg libs`.
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
## Paper
|
||||
|
||||
https://tonyzhaozh.github.io/aloha
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{zhao2023learning,
|
||||
title={Learning fine-grained bimanual manipulation with low-cost hardware},
|
||||
author={Zhao, Tony Z and Kumar, Vikash and Levine, Sergey and Finn, Chelsea},
|
||||
journal={arXiv preprint arXiv:2304.13705},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,14 @@
|
||||
## Paper
|
||||
|
||||
https://diffusion-policy.cs.columbia.edu
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{chi2024diffusionpolicy,
|
||||
author = {Cheng Chi and Zhenjia Xu and Siyuan Feng and Eric Cousineau and Yilun Du and Benjamin Burchfiel and Russ Tedrake and Shuran Song},
|
||||
title ={Diffusion Policy: Visuomotor Policy Learning via Action Diffusion},
|
||||
journal = {The International Journal of Robotics Research},
|
||||
year = {2024},
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,14 @@
|
||||
## Paper
|
||||
|
||||
https://arxiv.org/abs/2506.01844
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{shukor2025smolvla,
|
||||
title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics},
|
||||
author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi},
|
||||
journal={arXiv preprint arXiv:2506.01844},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,14 @@
|
||||
## Paper
|
||||
|
||||
https://www.nicklashansen.com/td-mpc/
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@inproceedings{Hansen2022tdmpc,
|
||||
title={Temporal Difference Learning for Model Predictive Control},
|
||||
author={Nicklas Hansen and Xiaolong Wang and Hao Su},
|
||||
booktitle={ICML},
|
||||
year={2022}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,14 @@
|
||||
## Paper
|
||||
|
||||
https://sjlee.cc/vq-bet/
|
||||
|
||||
## Citation
|
||||
|
||||
```bibtex
|
||||
@article{lee2024behavior,
|
||||
title={Behavior generation with latent actions},
|
||||
author={Lee, Seungjae and Wang, Yibin and Etukuru, Haritheja and Kim, H Jin and Shafiullah, Nur Muhammad Mahi and Pinto, Lerrel},
|
||||
journal={arXiv preprint arXiv:2403.03181},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
+6
-7
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.2.0"
|
||||
version = "0.3.3"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -61,22 +61,23 @@ dependencies = [
|
||||
# Hugging Face dependencies
|
||||
"datasets>=2.19.0,<=3.6.0", # TODO: Bumb dependency
|
||||
"diffusers>=0.27.2",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.27.1,<0.34.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2",
|
||||
|
||||
# Core dependencies
|
||||
"cmake>=3.29.0.1",
|
||||
"einops>=0.8.0",
|
||||
"opencv-python-headless>=4.9.0",
|
||||
"av>=14.2.0",
|
||||
"torch>=2.2.1",
|
||||
"torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')",
|
||||
"torchvision>=0.21.0",
|
||||
"jsonlines>=4.0.0",
|
||||
"packaging>=24.2",
|
||||
"pynput>=1.7.7",
|
||||
"pyserial>=3.5",
|
||||
"wandb>=0.20.0",
|
||||
|
||||
"torch>=2.2.1,<2.8.0", # TODO: Bumb dependency
|
||||
"torchcodec>=0.2.1,<0.6.0; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", # TODO: Bumb dependency
|
||||
"torchvision>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
|
||||
"draccus==0.10.0", # TODO: Remove ==
|
||||
"gymnasium>=0.29.1,<1.0.0", # TODO: Bumb dependency
|
||||
"rerun-sdk>=0.21.0,<0.23.0", # TODO: Bumb dependency
|
||||
@@ -125,7 +126,6 @@ hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.9", "lerobot[grpcio-dep]",
|
||||
async = ["lerobot[grpcio-dep]", "matplotlib>=3.10.3"]
|
||||
|
||||
# Development
|
||||
docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"]
|
||||
dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1", "lerobot[grpcio-dep]", "grpcio-tools==1.73.1"]
|
||||
test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "mock-serial>=0.0.1 ; sys_platform != 'win32'"]
|
||||
video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"]
|
||||
@@ -147,7 +147,6 @@ all = [
|
||||
"lerobot[smolvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[docs]",
|
||||
"lerobot[dev]",
|
||||
"lerobot[test]",
|
||||
"lerobot[video_benchmark]",
|
||||
|
||||
@@ -0,0 +1,625 @@
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-macos.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.3.1
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.2.1
|
||||
# via
|
||||
# flask
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
# cmeel-console-bridge
|
||||
# cmeel-octomap
|
||||
# cmeel-qhull
|
||||
# cmeel-tinyxml2
|
||||
# cmeel-urdfdom
|
||||
# cmeel-zlib
|
||||
# coal-library
|
||||
# eigenpy
|
||||
# eiquadprog
|
||||
# pin
|
||||
# placo
|
||||
# rhoban-cmeel-jsoncpp
|
||||
cmeel-assimp==5.4.3.1
|
||||
# via coal-library
|
||||
cmeel-boost==1.87.0.1
|
||||
# via
|
||||
# coal-library
|
||||
# eigenpy
|
||||
# eiquadprog
|
||||
# pin
|
||||
cmeel-console-bridge==1.0.2.3
|
||||
# via cmeel-urdfdom
|
||||
cmeel-octomap==1.10.0
|
||||
# via coal-library
|
||||
cmeel-qhull==8.0.2.1
|
||||
# via coal-library
|
||||
cmeel-tinyxml2==10.0.0
|
||||
# via cmeel-urdfdom
|
||||
cmeel-urdfdom==4.0.1
|
||||
# via pin
|
||||
cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
# via lerobot
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# torch
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
grpcio-tools==1.73.1
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
# via
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
# via ipython
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# imageio
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# mujoco
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# rerun-sdk
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
# via deepdiff
|
||||
packaging==25.0
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# pytest
|
||||
# scikit-image
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# scikit-image
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
# via
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.0
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# pytest
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyobjc-core==11.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-cocoa
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-applicationservices==11.1
|
||||
# via pynput
|
||||
pyobjc-framework-cocoa==11.1
|
||||
# via
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
# pyobjc-framework-quartz
|
||||
pyobjc-framework-coretext==11.1
|
||||
# via pyobjc-framework-applicationservices
|
||||
pyobjc-framework-quartz==11.1
|
||||
# via
|
||||
# pynput
|
||||
# pyobjc-framework-applicationservices
|
||||
# pyobjc-framework-coretext
|
||||
pyopengl==3.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyrealsense2-macosx==2.54.2
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
# via
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
# via
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# huggingface-hub
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# transformers
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
# via lerobot
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
typing-extensions==4.14.1
|
||||
# via
|
||||
# aiosignal
|
||||
# exceptiongroup
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# rerun-sdk
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
# via prompt-toolkit
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
@@ -0,0 +1,650 @@
|
||||
#
|
||||
# This file is autogenerated by pip-compile with Python 3.10
|
||||
# by the following command:
|
||||
#
|
||||
# pip-compile --output-file=requirements-ubuntu.txt requirements.in
|
||||
#
|
||||
-e .[all]
|
||||
# via -[all]
|
||||
absl-py==2.3.1
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# labmaze
|
||||
# mujoco
|
||||
accelerate==1.9.0
|
||||
# via lerobot
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.12.15
|
||||
# via fsspec
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
asttokens==3.0.0
|
||||
# via stack-data
|
||||
async-timeout==5.0.1
|
||||
# via aiohttp
|
||||
attrs==25.3.0
|
||||
# via
|
||||
# aiohttp
|
||||
# dm-tree
|
||||
# jsonlines
|
||||
# rerun-sdk
|
||||
av==15.0.0
|
||||
# via lerobot
|
||||
blinker==1.9.0
|
||||
# via flask
|
||||
certifi==2025.7.14
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
cffi==1.17.1
|
||||
# via pymunk
|
||||
cfgv==3.4.0
|
||||
# via pre-commit
|
||||
charset-normalizer==3.4.2
|
||||
# via requests
|
||||
click==8.2.1
|
||||
# via
|
||||
# flask
|
||||
# wandb
|
||||
cloudpickle==3.1.1
|
||||
# via gymnasium
|
||||
cmake==4.0.3
|
||||
# via lerobot
|
||||
cmeel==0.57.3
|
||||
# via
|
||||
# cmeel-assimp
|
||||
# cmeel-boost
|
||||
# cmeel-console-bridge
|
||||
# cmeel-octomap
|
||||
# cmeel-qhull
|
||||
# cmeel-tinyxml2
|
||||
# cmeel-urdfdom
|
||||
# cmeel-zlib
|
||||
# coal-library
|
||||
# eigenpy
|
||||
# eiquadprog
|
||||
# pin
|
||||
# placo
|
||||
# rhoban-cmeel-jsoncpp
|
||||
cmeel-assimp==5.4.3.1
|
||||
# via coal-library
|
||||
cmeel-boost==1.87.0.1
|
||||
# via
|
||||
# coal-library
|
||||
# eigenpy
|
||||
# eiquadprog
|
||||
# pin
|
||||
cmeel-console-bridge==1.0.2.3
|
||||
# via cmeel-urdfdom
|
||||
cmeel-octomap==1.10.0
|
||||
# via coal-library
|
||||
cmeel-qhull==8.0.2.1
|
||||
# via coal-library
|
||||
cmeel-tinyxml2==10.0.0
|
||||
# via cmeel-urdfdom
|
||||
cmeel-urdfdom==4.0.1
|
||||
# via pin
|
||||
cmeel-zlib==1.3.1
|
||||
# via cmeel-assimp
|
||||
coal-library==3.0.1
|
||||
# via pin
|
||||
contourpy==1.3.2
|
||||
# via matplotlib
|
||||
coverage[toml]==7.10.1
|
||||
# via pytest-cov
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
datasets==3.6.0
|
||||
# via lerobot
|
||||
debugpy==1.8.15
|
||||
# via lerobot
|
||||
decorator==5.2.1
|
||||
# via ipython
|
||||
deepdiff==8.5.0
|
||||
# via lerobot
|
||||
diffusers==0.34.0
|
||||
# via lerobot
|
||||
dill==0.3.8
|
||||
# via
|
||||
# datasets
|
||||
# multiprocess
|
||||
distlib==0.4.0
|
||||
# via virtualenv
|
||||
dm-control==1.0.14
|
||||
# via gym-aloha
|
||||
dm-env==1.6
|
||||
# via dm-control
|
||||
dm-tree==0.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# dm-env
|
||||
docopt==0.6.2
|
||||
# via num2words
|
||||
draccus==0.10.0
|
||||
# via lerobot
|
||||
dynamixel-sdk==3.7.31
|
||||
# via lerobot
|
||||
eigenpy==3.10.3
|
||||
# via coal-library
|
||||
einops==0.8.1
|
||||
# via lerobot
|
||||
eiquadprog==1.2.9
|
||||
# via placo
|
||||
evdev==1.9.2
|
||||
# via pynput
|
||||
exceptiongroup==1.3.0
|
||||
# via
|
||||
# ipython
|
||||
# pytest
|
||||
executing==2.2.0
|
||||
# via stack-data
|
||||
farama-notifications==0.0.4
|
||||
# via gymnasium
|
||||
feetech-servo-sdk==1.0.0
|
||||
# via lerobot
|
||||
filelock==3.18.0
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# huggingface-hub
|
||||
# torch
|
||||
# transformers
|
||||
# virtualenv
|
||||
flask==3.1.1
|
||||
# via lerobot
|
||||
fonttools==4.59.0
|
||||
# via matplotlib
|
||||
frozenlist==1.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
fsspec[http]==2025.3.0
|
||||
# via
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# torch
|
||||
gitdb==4.0.12
|
||||
# via gitpython
|
||||
gitpython==3.1.45
|
||||
# via wandb
|
||||
glfw==2.9.0
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
grpcio==1.73.1
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
grpcio-tools==1.73.1
|
||||
# via lerobot
|
||||
gym-aloha==0.1.1
|
||||
# via lerobot
|
||||
gym-hil==0.1.10
|
||||
# via lerobot
|
||||
gym-pusht==0.1.5
|
||||
# via lerobot
|
||||
gym-xarm==0.1.1
|
||||
# via lerobot
|
||||
gymnasium==0.29.1
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# pettingzoo
|
||||
gymnasium-robotics==1.2.4
|
||||
# via gym-xarm
|
||||
hf-transfer==0.1.9
|
||||
# via huggingface-hub
|
||||
hf-xet==1.1.5
|
||||
# via huggingface-hub
|
||||
hidapi==0.14.0.post4
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
huggingface-hub[cli,hf-transfer]==0.34.3
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# diffusers
|
||||
# lerobot
|
||||
# tokenizers
|
||||
# transformers
|
||||
identify==2.6.12
|
||||
# via pre-commit
|
||||
idna==3.10
|
||||
# via
|
||||
# requests
|
||||
# yarl
|
||||
imageio[ffmpeg]==2.37.0
|
||||
# via
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gymnasium-robotics
|
||||
# lerobot
|
||||
# scikit-image
|
||||
imageio-ffmpeg==0.6.0
|
||||
# via imageio
|
||||
importlib-metadata==8.7.0
|
||||
# via diffusers
|
||||
iniconfig==2.1.0
|
||||
# via pytest
|
||||
inquirerpy==0.3.4
|
||||
# via huggingface-hub
|
||||
ipython==8.37.0
|
||||
# via meshcat
|
||||
ischedule==1.2.7
|
||||
# via placo
|
||||
itsdangerous==2.2.0
|
||||
# via flask
|
||||
jedi==0.19.2
|
||||
# via ipython
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# flask
|
||||
# gymnasium-robotics
|
||||
# torch
|
||||
jsonlines==4.0.0
|
||||
# via lerobot
|
||||
kiwisolver==1.4.8
|
||||
# via matplotlib
|
||||
labmaze==1.0.6
|
||||
# via dm-control
|
||||
lazy-loader==0.4
|
||||
# via scikit-image
|
||||
lxml==6.0.0
|
||||
# via dm-control
|
||||
markupsafe==3.0.2
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# werkzeug
|
||||
matplotlib==3.10.5
|
||||
# via lerobot
|
||||
matplotlib-inline==0.1.7
|
||||
# via ipython
|
||||
mergedeep==1.3.4
|
||||
# via draccus
|
||||
meshcat==0.3.2
|
||||
# via placo
|
||||
mock-serial==0.0.1
|
||||
# via lerobot
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
mujoco==2.3.7
|
||||
# via
|
||||
# dm-control
|
||||
# gym-aloha
|
||||
# gym-hil
|
||||
# gym-xarm
|
||||
# gymnasium-robotics
|
||||
multidict==6.6.3
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
multiprocess==0.70.16
|
||||
# via datasets
|
||||
mypy-extensions==1.1.0
|
||||
# via typing-inspect
|
||||
networkx==3.4.2
|
||||
# via
|
||||
# scikit-image
|
||||
# torch
|
||||
nodeenv==1.9.1
|
||||
# via pre-commit
|
||||
num2words==0.5.14
|
||||
# via lerobot
|
||||
numpy==2.2.6
|
||||
# via
|
||||
# accelerate
|
||||
# cmeel-boost
|
||||
# contourpy
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# dm-env
|
||||
# dm-tree
|
||||
# gymnasium
|
||||
# gymnasium-robotics
|
||||
# imageio
|
||||
# labmaze
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# mujoco
|
||||
# opencv-python
|
||||
# opencv-python-headless
|
||||
# pandas
|
||||
# pettingzoo
|
||||
# rerun-sdk
|
||||
# scikit-image
|
||||
# scipy
|
||||
# shapely
|
||||
# tifffile
|
||||
# torchvision
|
||||
# transformers
|
||||
nvidia-cublas-cu12==12.6.4.1
|
||||
# via
|
||||
# nvidia-cudnn-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cuda-cupti-cu12==12.6.80
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.6.77
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.6.77
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==9.5.1.17
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.3.0.4
|
||||
# via torch
|
||||
nvidia-cufile-cu12==1.11.1.6
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.7.77
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.7.1.2
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.5.4.2
|
||||
# via
|
||||
# nvidia-cusolver-cu12
|
||||
# torch
|
||||
nvidia-cusparselt-cu12==0.6.3
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.26.2
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.6.85
|
||||
# via
|
||||
# nvidia-cufft-cu12
|
||||
# nvidia-cusolver-cu12
|
||||
# nvidia-cusparse-cu12
|
||||
# torch
|
||||
nvidia-nvtx-cu12==12.6.77
|
||||
# via torch
|
||||
opencv-python==4.12.0.88
|
||||
# via gym-pusht
|
||||
opencv-python-headless==4.12.0.88
|
||||
# via lerobot
|
||||
orderly-set==5.5.0
|
||||
# via deepdiff
|
||||
packaging==25.0
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# huggingface-hub
|
||||
# lazy-loader
|
||||
# lerobot
|
||||
# matplotlib
|
||||
# pytest
|
||||
# scikit-image
|
||||
# transformers
|
||||
# wandb
|
||||
pandas==2.3.1
|
||||
# via
|
||||
# datasets
|
||||
# lerobot
|
||||
parso==0.8.4
|
||||
# via jedi
|
||||
pettingzoo==1.24.3
|
||||
# via gymnasium-robotics
|
||||
pexpect==4.9.0
|
||||
# via ipython
|
||||
pfzy==0.3.4
|
||||
# via inquirerpy
|
||||
pillow==11.3.0
|
||||
# via
|
||||
# diffusers
|
||||
# imageio
|
||||
# matplotlib
|
||||
# meshcat
|
||||
# rerun-sdk
|
||||
# scikit-image
|
||||
# torchvision
|
||||
pin==3.4.0
|
||||
# via placo
|
||||
placo==0.9.14
|
||||
# via lerobot
|
||||
platformdirs==4.3.8
|
||||
# via
|
||||
# virtualenv
|
||||
# wandb
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
pre-commit==4.2.0
|
||||
# via lerobot
|
||||
prompt-toolkit==3.0.51
|
||||
# via
|
||||
# inquirerpy
|
||||
# ipython
|
||||
propcache==0.3.2
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
protobuf==6.31.0
|
||||
# via
|
||||
# dm-control
|
||||
# grpcio-tools
|
||||
# lerobot
|
||||
# wandb
|
||||
psutil==7.0.0
|
||||
# via
|
||||
# accelerate
|
||||
# imageio
|
||||
ptyprocess==0.7.0
|
||||
# via pexpect
|
||||
pure-eval==0.2.3
|
||||
# via stack-data
|
||||
pyarrow==21.0.0
|
||||
# via
|
||||
# datasets
|
||||
# rerun-sdk
|
||||
pycparser==2.22
|
||||
# via cffi
|
||||
pydantic==2.11.7
|
||||
# via wandb
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pygame==2.6.1
|
||||
# via
|
||||
# gym-hil
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pygments==2.19.2
|
||||
# via
|
||||
# ipython
|
||||
# pytest
|
||||
pymunk==6.11.1
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
pyngrok==7.2.12
|
||||
# via meshcat
|
||||
pynput==1.8.1
|
||||
# via
|
||||
# gym-hil
|
||||
# lerobot
|
||||
pyopengl==3.1.9
|
||||
# via
|
||||
# dm-control
|
||||
# mujoco
|
||||
pyparsing==3.2.3
|
||||
# via
|
||||
# dm-control
|
||||
# matplotlib
|
||||
pyrealsense2==2.56.5.9235
|
||||
# via lerobot
|
||||
pyserial==3.5
|
||||
# via
|
||||
# dynamixel-sdk
|
||||
# feetech-servo-sdk
|
||||
# lerobot
|
||||
pytest==8.4.1
|
||||
# via
|
||||
# lerobot
|
||||
# pytest-cov
|
||||
# pytest-timeout
|
||||
pytest-cov==6.2.1
|
||||
# via lerobot
|
||||
pytest-timeout==2.4.0
|
||||
# via lerobot
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# matplotlib
|
||||
# pandas
|
||||
python-xlib==0.33
|
||||
# via pynput
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pyyaml==6.0.2
|
||||
# via
|
||||
# accelerate
|
||||
# datasets
|
||||
# draccus
|
||||
# huggingface-hub
|
||||
# pre-commit
|
||||
# pyngrok
|
||||
# pyyaml-include
|
||||
# transformers
|
||||
# wandb
|
||||
pyyaml-include==1.4.1
|
||||
# via draccus
|
||||
pyzmq==27.0.0
|
||||
# via
|
||||
# lerobot
|
||||
# meshcat
|
||||
regex==2025.7.34
|
||||
# via
|
||||
# diffusers
|
||||
# transformers
|
||||
requests==2.32.4
|
||||
# via
|
||||
# datasets
|
||||
# diffusers
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
# wandb
|
||||
rerun-sdk==0.22.1
|
||||
# via lerobot
|
||||
rhoban-cmeel-jsoncpp==1.9.4.9
|
||||
# via placo
|
||||
safetensors==0.5.3
|
||||
# via
|
||||
# accelerate
|
||||
# diffusers
|
||||
# lerobot
|
||||
# transformers
|
||||
scikit-image==0.25.2
|
||||
# via
|
||||
# gym-pusht
|
||||
# lerobot
|
||||
scipy==1.15.3
|
||||
# via
|
||||
# dm-control
|
||||
# scikit-image
|
||||
sentry-sdk==2.34.1
|
||||
# via wandb
|
||||
shapely==2.1.1
|
||||
# via gym-pusht
|
||||
six==1.17.0
|
||||
# via
|
||||
# pynput
|
||||
# python-dateutil
|
||||
# python-xlib
|
||||
smmap==5.0.2
|
||||
# via gitdb
|
||||
stack-data==0.6.3
|
||||
# via ipython
|
||||
sympy==1.14.0
|
||||
# via torch
|
||||
termcolor==3.1.0
|
||||
# via lerobot
|
||||
tifffile==2025.5.10
|
||||
# via scikit-image
|
||||
tokenizers==0.21.4
|
||||
# via transformers
|
||||
toml==0.10.2
|
||||
# via draccus
|
||||
tomli==2.2.1
|
||||
# via
|
||||
# cmeel
|
||||
# coverage
|
||||
# pytest
|
||||
torch==2.7.1
|
||||
# via
|
||||
# accelerate
|
||||
# lerobot
|
||||
# torchvision
|
||||
torchcodec==0.5
|
||||
# via lerobot
|
||||
torchvision==0.22.1
|
||||
# via lerobot
|
||||
tornado==6.5.1
|
||||
# via meshcat
|
||||
tqdm==4.67.1
|
||||
# via
|
||||
# datasets
|
||||
# dm-control
|
||||
# huggingface-hub
|
||||
# transformers
|
||||
traitlets==5.14.3
|
||||
# via
|
||||
# ipython
|
||||
# matplotlib-inline
|
||||
transformers==4.51.3
|
||||
# via lerobot
|
||||
triton==3.3.1
|
||||
# via torch
|
||||
typing-extensions==4.14.1
|
||||
# via
|
||||
# aiosignal
|
||||
# exceptiongroup
|
||||
# gymnasium
|
||||
# huggingface-hub
|
||||
# ipython
|
||||
# multidict
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# rerun-sdk
|
||||
# torch
|
||||
# typing-inspect
|
||||
# typing-inspection
|
||||
# wandb
|
||||
typing-inspect==0.9.0
|
||||
# via draccus
|
||||
typing-inspection==0.4.1
|
||||
# via pydantic
|
||||
tzdata==2025.2
|
||||
# via pandas
|
||||
u-msgpack-python==2.8.0
|
||||
# via meshcat
|
||||
urllib3==2.5.0
|
||||
# via
|
||||
# requests
|
||||
# sentry-sdk
|
||||
virtualenv==20.32.0
|
||||
# via pre-commit
|
||||
wandb==0.21.0
|
||||
# via lerobot
|
||||
wcwidth==0.2.13
|
||||
# via prompt-toolkit
|
||||
werkzeug==3.1.3
|
||||
# via flask
|
||||
wrapt==1.17.2
|
||||
# via dm-tree
|
||||
xxhash==3.5.0
|
||||
# via datasets
|
||||
yarl==1.20.1
|
||||
# via aiohttp
|
||||
zipp==3.23.0
|
||||
# via importlib-metadata
|
||||
|
||||
# The following packages are considered to be unsafe in a requirements file:
|
||||
# setuptools
|
||||
@@ -0,0 +1,9 @@
|
||||
# requirements.in
|
||||
|
||||
# requirements-macos.txt was generated on macOS and is platform-specific (macOS 15.5 24F74 arm64).
|
||||
# Darwin MacBook-Pro.local 24.5.0 Darwin Kernel Version 24.5.0: Tue Apr 22 19:54:43 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T8132 arm64
|
||||
|
||||
# requirements-ubuntu.txt was generated on Linux and is platform-specific (Ubuntu 24.04.2 LTS x86_64).
|
||||
# Linux mlerobot-linux 6.14.0-27-generic #27~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Tue Jul 22 17:38:49 UTC 2 x86_64 x86_64 x86_64 GNU/Linux
|
||||
|
||||
-e .[all]
|
||||
@@ -82,5 +82,9 @@ def calibrate(cfg: CalibrateConfig):
|
||||
device.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
calibrate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -368,7 +368,7 @@ class OpenCVCamera(Camera):
|
||||
if requested_color_mode == ColorMode.RGB:
|
||||
processed_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
|
||||
processed_image = cv2.rotate(processed_image, self.rotation)
|
||||
|
||||
return processed_image
|
||||
|
||||
@@ -434,7 +434,7 @@ class RealSenseCamera(Camera):
|
||||
if self.color_mode == ColorMode.BGR:
|
||||
processed_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
||||
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE]:
|
||||
if self.rotation in [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_90_COUNTERCLOCKWISE, cv2.ROTATE_180]:
|
||||
processed_image = cv2.rotate(processed_image, self.rotation)
|
||||
|
||||
return processed_image
|
||||
|
||||
@@ -27,6 +27,7 @@ from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
@@ -119,8 +120,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
|
||||
@property
|
||||
def robot_state_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE:
|
||||
for ft_name, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE and ft_name == OBS_STATE:
|
||||
return ft
|
||||
return None
|
||||
|
||||
@@ -137,8 +138,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION:
|
||||
for ft_name, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION and ft_name == ACTION:
|
||||
return ft
|
||||
return None
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
|
||||
@EnvConfig.register_subclass("aloha")
|
||||
@dataclass
|
||||
class AlohaEnv(EnvConfig):
|
||||
task: str = "AlohaInsertion-v0"
|
||||
task: str | None = "AlohaInsertion-v0"
|
||||
fps: int = 50
|
||||
episode_length: int = 400
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -82,7 +82,7 @@ class AlohaEnv(EnvConfig):
|
||||
@EnvConfig.register_subclass("pusht")
|
||||
@dataclass
|
||||
class PushtEnv(EnvConfig):
|
||||
task: str = "PushT-v0"
|
||||
task: str | None = "PushT-v0"
|
||||
fps: int = 10
|
||||
episode_length: int = 300
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -124,7 +124,7 @@ class PushtEnv(EnvConfig):
|
||||
@EnvConfig.register_subclass("xarm")
|
||||
@dataclass
|
||||
class XarmEnv(EnvConfig):
|
||||
task: str = "XarmLift-v0"
|
||||
task: str | None = "XarmLift-v0"
|
||||
fps: int = 15
|
||||
episode_length: int = 200
|
||||
obs_type: str = "pixels_agent_pos"
|
||||
@@ -200,10 +200,10 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
fps: int = 10
|
||||
name: str = "real_robot"
|
||||
mode: str = None # Either "record", "replay", None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
task: str = ""
|
||||
task: str | None = ""
|
||||
num_episodes: int = 10 # only for record mode
|
||||
episode: int = 0
|
||||
device: str = "cuda"
|
||||
@@ -213,6 +213,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
# For the reward classifier, to record more positive examples after a success
|
||||
number_of_steps_after_success: int = 0
|
||||
|
||||
@property
|
||||
def gym_kwargs(self) -> dict:
|
||||
return {}
|
||||
|
||||
@@ -222,9 +223,8 @@ class HILSerlRobotEnvConfig(EnvConfig):
|
||||
class HILEnvConfig(EnvConfig):
|
||||
"""Configuration for the HIL environment."""
|
||||
|
||||
type: str = "hil"
|
||||
name: str = "PandaPickCube"
|
||||
task: str = "PandaPickCubeKeyboard-v0"
|
||||
task: str | None = "PandaPickCubeKeyboard-v0"
|
||||
use_viewer: bool = True
|
||||
gripper_penalty: float = 0.0
|
||||
use_gamepad: bool = True
|
||||
@@ -252,7 +252,7 @@ class HILEnvConfig(EnvConfig):
|
||||
robot_config: RobotConfig | None = None
|
||||
teleop_config: TeleoperatorConfig | None = None
|
||||
wrapper: EnvTransformConfig | None = None
|
||||
mode: str = None # Either "record", "replay", None
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
repo_id: str | None = None
|
||||
dataset_root: str | None = None
|
||||
num_episodes: int = 10 # only for record mode
|
||||
|
||||
@@ -286,7 +286,7 @@ def save_images_from_all_cameras(
|
||||
print(f"Image capture finished. Images saved to {output_dir}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Unified camera utility script for listing cameras and capturing images."
|
||||
)
|
||||
@@ -313,3 +313,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
save_images_from_all_cameras(**vars(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -61,5 +61,9 @@ def find_port():
|
||||
raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
find_port()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_act_README.md
|
||||
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_diffusion_README.md
|
||||
@@ -217,12 +217,13 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
)
|
||||
|
||||
# Check that all input images have the same shape.
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
if len(self.image_features) > 0:
|
||||
first_image_key, first_image_ft = next(iter(self.image_features.items()))
|
||||
for key, image_ft in self.image_features.items():
|
||||
if image_ft.shape != first_image_ft.shape:
|
||||
raise ValueError(
|
||||
f"`{key}` does not match `{first_image_key}`, but we expect all image shapes to match."
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list:
|
||||
|
||||
@@ -288,7 +288,7 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, environment_dim)
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
}
|
||||
"""
|
||||
batch_size, n_obs_steps = batch["observation.state"].shape[:2]
|
||||
@@ -315,7 +315,7 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
"observation.images": (B, n_obs_steps, num_cameras, C, H, W)
|
||||
AND/OR
|
||||
"observation.environment_state": (B, environment_dim)
|
||||
"observation.environment_state": (B, n_obs_steps, environment_dim)
|
||||
|
||||
"action": (B, horizon, action_dim)
|
||||
"action_is_pad": (B, horizon)
|
||||
|
||||
@@ -66,7 +66,8 @@ from lerobot.policies.pi0.paligemma_with_expert import (
|
||||
PaliGemmaWithExpertModel,
|
||||
)
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
from lerobot.utils.utils import get_safe_dtype, init_logging
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
@@ -90,12 +91,6 @@ def create_sinusoidal_pos_embedding(
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
@@ -258,6 +253,99 @@ class PI0Policy(PreTrainedPolicy):
|
||||
"""This should be called whenever the environment is reset."""
|
||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||
|
||||
@classmethod
|
||||
def _transform_state_dict_keys(cls, state_dict: dict) -> dict:
|
||||
"""
|
||||
Transform state dict keys to match expected model structure.
|
||||
|
||||
Transformations:
|
||||
- model.paligemma_with_expert.paligemma.language_model.lm_head ->
|
||||
model.paligemma_with_expert.paligemma.lm_head
|
||||
- model.paligemma_with_expert.paligemma.language_model.model ->
|
||||
model.paligemma_with_expert.paligemma.model.language_model
|
||||
- model.paligemma_with_expert.paligemma.vision_tower ->
|
||||
model.paligemma_with_expert.paligemma.model.vision_tower
|
||||
- model.paligemma_with_expert.paligemma.multi_modal_projector ->
|
||||
model.paligemma_with_expert.paligemma.model.multi_modal_projector
|
||||
|
||||
Also handles tied weights between lm_head.weight and
|
||||
embed_tokens.weight.
|
||||
"""
|
||||
import re
|
||||
|
||||
transformed_dict = {}
|
||||
|
||||
transformations = [
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.lm_head"),
|
||||
".paligemma_with_expert.paligemma.lm_head",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.language_model\.model"),
|
||||
".paligemma_with_expert.paligemma.model.language_model",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.vision_tower"),
|
||||
".paligemma_with_expert.paligemma.model.vision_tower",
|
||||
),
|
||||
(
|
||||
re.compile(r"\.paligemma_with_expert\.paligemma\.multi_modal_projector"),
|
||||
".paligemma_with_expert.paligemma.model.multi_modal_projector",
|
||||
),
|
||||
]
|
||||
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
for pattern, replacement in transformations:
|
||||
new_key = pattern.sub(replacement, new_key)
|
||||
transformed_dict[new_key] = value
|
||||
|
||||
# Handle tied weights: lm_head.weight and embed_tokens.weight share memory
|
||||
lm_head_key = None
|
||||
embed_tokens_key = None
|
||||
|
||||
for key in transformed_dict:
|
||||
if key.endswith(".paligemma_with_expert.paligemma.lm_head.weight"):
|
||||
lm_head_key = key
|
||||
elif key.endswith(".paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"):
|
||||
embed_tokens_key = key
|
||||
if lm_head_key and embed_tokens_key:
|
||||
break
|
||||
|
||||
if lm_head_key and not embed_tokens_key:
|
||||
embed_tokens_key = lm_head_key.replace(
|
||||
".lm_head.weight", ".model.language_model.embed_tokens.weight"
|
||||
)
|
||||
transformed_dict[embed_tokens_key] = transformed_dict[lm_head_key]
|
||||
elif embed_tokens_key and not lm_head_key:
|
||||
lm_head_key = embed_tokens_key.replace(
|
||||
".model.language_model.embed_tokens.weight", ".lm_head.weight"
|
||||
)
|
||||
transformed_dict[lm_head_key] = transformed_dict[embed_tokens_key]
|
||||
|
||||
return transformed_dict
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(
|
||||
cls, model: "PI0Policy", model_file: str, map_location: str, strict: bool
|
||||
) -> "PI0Policy":
|
||||
"""Override to apply key transformations before loading."""
|
||||
from safetensors.torch import load_file
|
||||
|
||||
init_logging()
|
||||
# Load the state dict from file safely
|
||||
state_dict = load_file(model_file, device=map_location)
|
||||
|
||||
# Apply key transformations
|
||||
transformed_state_dict = cls._transform_state_dict_keys(state_dict)
|
||||
|
||||
# Load the transformed state dict
|
||||
msg = model.load_state_dict(transformed_state_dict, strict=strict)
|
||||
|
||||
# Log message
|
||||
log_model_loading_keys(msg.missing_keys, msg.unexpected_keys)
|
||||
return model
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
@@ -515,9 +603,10 @@ class PI0FlowMatching(nn.Module):
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
return time
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks
|
||||
|
||||
@@ -30,6 +30,7 @@ from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.policies.utils import log_model_loading_keys
|
||||
from lerobot.utils.hub import HubMixin
|
||||
|
||||
T = TypeVar("T", bound="PreTrainedPolicy")
|
||||
@@ -128,18 +129,26 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
|
||||
@classmethod
|
||||
def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
|
||||
if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
|
||||
load_model_as_safetensor(model, model_file, strict=strict)
|
||||
if map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
else:
|
||||
safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
|
||||
# Create base kwargs
|
||||
kwargs = {"strict": strict}
|
||||
|
||||
# Add device parameter for newer versions that support it
|
||||
if packaging.version.parse(safetensors.__version__) >= packaging.version.parse("0.4.3"):
|
||||
kwargs["device"] = map_location
|
||||
|
||||
# Load the model with appropriate kwargs
|
||||
missing_keys, unexpected_keys = load_model_as_safetensor(model, model_file, **kwargs)
|
||||
log_model_loading_keys(missing_keys, unexpected_keys)
|
||||
|
||||
# For older versions, manually move to device if needed
|
||||
if "device" not in kwargs and map_location != "cpu":
|
||||
logging.warning(
|
||||
"Loading model weights on other devices than 'cpu' is not supported natively in your version of safetensors."
|
||||
" This means that the model is loaded on 'cpu' first and then copied to the device."
|
||||
" This leads to a slower loading time."
|
||||
" Please update safetensors to version 0.4.3 or above for improved performance."
|
||||
)
|
||||
model.to(map_location)
|
||||
return model
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_smolvla_README.md
|
||||
@@ -194,12 +194,6 @@ def create_sinusoidal_pos_embedding(
|
||||
return pos_emb
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device):
|
||||
gamma1 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / alpha)
|
||||
gamma2 = torch.empty((bsize,), device=device).uniform_(0, 1).pow(1 / beta)
|
||||
return gamma1 / (gamma1 + gamma2)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks):
|
||||
"""Copied from big_vision.
|
||||
|
||||
@@ -690,9 +684,10 @@ class VLAFlowMatching(nn.Module):
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
time_beta = sample_beta(1.5, 1.0, bsize, device)
|
||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time.to(dtype=torch.float32, device=device)
|
||||
return time
|
||||
|
||||
def embed_prefix(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state: torch.Tensor = None
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_tdmpc_README.md
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
|
||||
import torch
|
||||
@@ -71,3 +72,16 @@ def get_output_shape(module: nn.Module, input_shape: tuple) -> tuple:
|
||||
with torch.inference_mode():
|
||||
output = module(dummy_input)
|
||||
return tuple(output.shape)
|
||||
|
||||
|
||||
def log_model_loading_keys(missing_keys: list[str], unexpected_keys: list[str]) -> None:
|
||||
"""Log missing and unexpected keys when loading a model.
|
||||
|
||||
Args:
|
||||
missing_keys (list[str]): Keys that were expected but not found.
|
||||
unexpected_keys (list[str]): Keys that were found but not expected.
|
||||
"""
|
||||
if missing_keys:
|
||||
logging.warning(f"Missing key(s) when loading model: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
logging.warning(f"Unexpected key(s) when loading model: {unexpected_keys}")
|
||||
|
||||
+1
@@ -0,0 +1 @@
|
||||
../../../../docs/source/policy_vqbet_README.md
|
||||
@@ -0,0 +1,54 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .device_processor import DeviceProcessor
|
||||
from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor
|
||||
from .observation_processor import VanillaObservationProcessor
|
||||
from .pipeline import (
|
||||
ActionProcessor,
|
||||
DoneProcessor,
|
||||
EnvTransition,
|
||||
IdentityProcessor,
|
||||
InfoProcessor,
|
||||
ObservationProcessor,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RewardProcessor,
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
TruncatedProcessor,
|
||||
)
|
||||
from .rename_processor import RenameProcessor
|
||||
|
||||
__all__ = [
|
||||
"ActionProcessor",
|
||||
"DeviceProcessor",
|
||||
"DoneProcessor",
|
||||
"EnvTransition",
|
||||
"IdentityProcessor",
|
||||
"InfoProcessor",
|
||||
"NormalizerProcessor",
|
||||
"UnnormalizerProcessor",
|
||||
"ObservationProcessor",
|
||||
"ProcessorStep",
|
||||
"ProcessorStepRegistry",
|
||||
"RenameProcessor",
|
||||
"RewardProcessor",
|
||||
"RobotProcessor",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessor",
|
||||
"VanillaObservationProcessor",
|
||||
]
|
||||
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import EnvTransition, TransitionKey
|
||||
from lerobot.utils.utils import get_safe_torch_device
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceProcessor:
|
||||
"""Processes transitions by moving tensors to the specified device.
|
||||
|
||||
This processor ensures that all tensors in the transition are moved to the
|
||||
specified device (CPU or GPU) before they are returned.
|
||||
"""
|
||||
|
||||
device: torch.device = "cpu"
|
||||
|
||||
def __post_init__(self):
|
||||
self.device = get_safe_torch_device(self.device)
|
||||
self.non_blocking = "cuda" in str(self.device)
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
# Create a copy of the transition
|
||||
new_transition = transition.copy()
|
||||
|
||||
# Process observation tensors
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if observation is not None:
|
||||
new_observation = {
|
||||
k: v.to(self.device, non_blocking=self.non_blocking) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in observation.items()
|
||||
}
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
|
||||
# Process action tensor
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
if action is not None and isinstance(action, torch.Tensor):
|
||||
new_transition[TransitionKey.ACTION] = action.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process reward tensor
|
||||
reward = transition.get(TransitionKey.REWARD)
|
||||
if reward is not None and isinstance(reward, torch.Tensor):
|
||||
new_transition[TransitionKey.REWARD] = reward.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process done tensor
|
||||
done = transition.get(TransitionKey.DONE)
|
||||
if done is not None and isinstance(done, torch.Tensor):
|
||||
new_transition[TransitionKey.DONE] = done.to(self.device, non_blocking=self.non_blocking)
|
||||
|
||||
# Process truncated tensor
|
||||
truncated = transition.get(TransitionKey.TRUNCATED)
|
||||
if truncated is not None and isinstance(truncated, torch.Tensor):
|
||||
new_transition[TransitionKey.TRUNCATED] = truncated.to(
|
||||
self.device, non_blocking=self.non_blocking
|
||||
)
|
||||
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return configuration for serialization."""
|
||||
return {"device": self.device}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -0,0 +1,331 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor.pipeline import EnvTransition, ProcessorStepRegistry, TransitionKey
|
||||
|
||||
|
||||
def _convert_stats_to_tensors(stats: dict[str, dict[str, Any]]) -> dict[str, dict[str, Tensor]]:
|
||||
"""Convert numpy arrays and other types to torch tensors."""
|
||||
tensor_stats: dict[str, dict[str, Tensor]] = {}
|
||||
for key, sub in stats.items():
|
||||
tensor_stats[key] = {}
|
||||
for stat_name, value in sub.items():
|
||||
if isinstance(value, np.ndarray):
|
||||
tensor_val = torch.from_numpy(value.astype(np.float32))
|
||||
elif isinstance(value, torch.Tensor):
|
||||
tensor_val = value.to(dtype=torch.float32)
|
||||
elif isinstance(value, (int, float, list, tuple)):
|
||||
tensor_val = torch.tensor(value, dtype=torch.float32)
|
||||
else:
|
||||
raise TypeError(f"Unsupported type for stats['{key}']['{stat_name}']: {type(value)}")
|
||||
tensor_stats[key][stat_name] = tensor_val
|
||||
return tensor_stats
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="normalizer_processor")
|
||||
class NormalizerProcessor:
|
||||
"""Normalizes observations and actions in a single processor step.
|
||||
|
||||
This processor handles normalization of both observation and action tensors
|
||||
using either mean/std normalization or min/max scaling to a [-1, 1] range.
|
||||
|
||||
For each tensor key in the stats dictionary, the processor will:
|
||||
- Use mean/std normalization if those statistics are provided: (x - mean) / std
|
||||
- Use min/max scaling if those statistics are provided: 2 * (x - min) / (max - min) - 1
|
||||
|
||||
The processor can be configured to normalize only specific keys by setting
|
||||
the normalize_keys parameter.
|
||||
"""
|
||||
|
||||
# Features and normalisation map are mandatory to match the design of normalize.py
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
|
||||
# Pre-computed statistics coming from dataset.meta.stats for instance.
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
# Explicit subset of keys to normalise. If ``None`` every key (except
|
||||
# "action") found in ``stats`` will be normalised. Using a ``set`` makes
|
||||
# membership checks O(1).
|
||||
normalize_keys: set[str] | None = None
|
||||
|
||||
eps: float = 1e-8
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
*,
|
||||
normalize_keys: set[str] | None = None,
|
||||
eps: float = 1e-8,
|
||||
) -> NormalizerProcessor:
|
||||
"""Factory helper that pulls statistics from a :class:`LeRobotDataset`.
|
||||
|
||||
The features and norm_map parameters are mandatory to match the design
|
||||
pattern used in normalize.py.
|
||||
"""
|
||||
|
||||
return cls(
|
||||
features=features,
|
||||
norm_map=norm_map,
|
||||
stats=dataset.meta.stats,
|
||||
normalize_keys=normalize_keys,
|
||||
eps=eps,
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
|
||||
# Convert statistics once so we avoid repeated numpy→Tensor conversions
|
||||
# during runtime.
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
# Ensure *normalize_keys* is a set for fast look-ups and compare by
|
||||
# value later when returning the configuration.
|
||||
if self.normalize_keys is not None and not isinstance(self.normalize_keys, set):
|
||||
self.normalize_keys = set(self.normalize_keys)
|
||||
|
||||
def _normalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
|
||||
# Decide which keys should be normalised for this call.
|
||||
if self.normalize_keys is not None:
|
||||
keys_to_norm = self.normalize_keys
|
||||
else:
|
||||
# Use feature map to skip action keys.
|
||||
keys_to_norm = {k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION}
|
||||
|
||||
processed = dict(observation)
|
||||
for key in keys_to_norm:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
continue
|
||||
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = (tensor - mean) / (std + self.eps)
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
return processed
|
||||
|
||||
def _normalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
return action
|
||||
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return (tensor - mean) / (std + self.eps)
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return 2 * (tensor - min_val) / (max_val - min_val + self.eps) - 1
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._normalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._normalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with normalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
config = {
|
||||
"eps": self.eps,
|
||||
"features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
if self.normalize_keys is not None:
|
||||
# Serialise as a list for YAML / JSON friendliness
|
||||
config["normalize_keys"] = sorted(self.normalize_keys)
|
||||
return config
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="unnormalizer_processor")
|
||||
class UnnormalizerProcessor:
|
||||
"""Inverse normalisation for observations and actions.
|
||||
|
||||
Exactly mirrors :class:`NormalizerProcessor` but applies the inverse
|
||||
transform.
|
||||
"""
|
||||
|
||||
features: dict[str, PolicyFeature]
|
||||
norm_map: dict[FeatureType, NormalizationMode]
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
|
||||
_tensor_stats: dict[str, dict[str, Tensor]] = field(default_factory=dict, init=False, repr=False)
|
||||
|
||||
@classmethod
|
||||
def from_lerobot_dataset(
|
||||
cls,
|
||||
dataset: LeRobotDataset,
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[FeatureType, NormalizationMode],
|
||||
) -> UnnormalizerProcessor:
|
||||
return cls(features=features, norm_map=norm_map, stats=dataset.meta.stats)
|
||||
|
||||
def __post_init__(self):
|
||||
# Handle deserialization from JSON config
|
||||
if self.features and isinstance(list(self.features.values())[0], dict):
|
||||
# Features came from JSON - need to reconstruct PolicyFeature objects
|
||||
reconstructed_features = {}
|
||||
for key, ft_dict in self.features.items():
|
||||
reconstructed_features[key] = PolicyFeature(
|
||||
type=FeatureType(ft_dict["type"]), shape=tuple(ft_dict["shape"])
|
||||
)
|
||||
self.features = reconstructed_features
|
||||
|
||||
if self.norm_map and isinstance(list(self.norm_map.keys())[0], str):
|
||||
# norm_map came from JSON - need to reconstruct enum keys and values
|
||||
reconstructed_norm_map = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed_norm_map[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed_norm_map
|
||||
|
||||
self.stats = self.stats or {}
|
||||
self._tensor_stats = _convert_stats_to_tensors(self.stats)
|
||||
|
||||
def _unnormalize_obs(self, observation):
|
||||
if observation is None:
|
||||
return None
|
||||
keys = [k for k, ft in self.features.items() if ft.type is not FeatureType.ACTION]
|
||||
processed = dict(observation)
|
||||
for key in keys:
|
||||
if key not in processed or key not in self._tensor_stats:
|
||||
continue
|
||||
orig_val = processed[key]
|
||||
tensor = (
|
||||
orig_val.to(dtype=torch.float32)
|
||||
if isinstance(orig_val, torch.Tensor)
|
||||
else torch.as_tensor(orig_val, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats[key].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
processed[key] = tensor * std + mean
|
||||
elif "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
processed[key] = (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
return processed
|
||||
|
||||
def _unnormalize_action(self, action):
|
||||
if action is None or "action" not in self._tensor_stats:
|
||||
return action
|
||||
tensor = (
|
||||
action.to(dtype=torch.float32)
|
||||
if isinstance(action, torch.Tensor)
|
||||
else torch.as_tensor(action, dtype=torch.float32)
|
||||
)
|
||||
stats = {k: v.to(tensor.device) for k, v in self._tensor_stats["action"].items()}
|
||||
if "mean" in stats and "std" in stats:
|
||||
mean, std = stats["mean"], stats["std"]
|
||||
return tensor * std + mean
|
||||
if "min" in stats and "max" in stats:
|
||||
min_val, max_val = stats["min"], stats["max"]
|
||||
return (tensor + 1) / 2 * (max_val - min_val) + min_val
|
||||
raise ValueError("Action stats must contain either ('mean','std') or ('min','max')")
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
observation = self._unnormalize_obs(transition.get(TransitionKey.OBSERVATION))
|
||||
action = self._unnormalize_action(transition.get(TransitionKey.ACTION))
|
||||
|
||||
# Create a new transition with unnormalized values
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = observation
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"features": {
|
||||
key: {"type": ft.type.value, "shape": ft.shape} for key, ft in self.features.items()
|
||||
},
|
||||
"norm_map": {ft_type.value: norm_mode.value for ft_type, norm_mode in self.norm_map.items()},
|
||||
}
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
flat = {}
|
||||
for key, sub in self._tensor_stats.items():
|
||||
for stat_name, tensor in sub.items():
|
||||
flat[f"{key}.{stat_name}"] = tensor
|
||||
return flat
|
||||
|
||||
def load_state_dict(self, state: Mapping[str, Tensor]) -> None:
|
||||
self._tensor_stats.clear()
|
||||
for flat_key, tensor in state.items():
|
||||
key, stat_name = flat_key.rsplit(".", 1)
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
return features
|
||||
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor.pipeline import ObservationProcessor, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="observation_processor")
|
||||
class VanillaObservationProcessor(ObservationProcessor):
|
||||
"""
|
||||
Processes environment observations into the LeRobot format by handling both images and states.
|
||||
|
||||
Image processing:
|
||||
- Converts channel-last (H, W, C) images to channel-first (C, H, W)
|
||||
- Normalizes uint8 images ([0, 255]) to float32 ([0, 1])
|
||||
- Adds a batch dimension if missing
|
||||
- Supports single images and image dictionaries
|
||||
|
||||
State processing:
|
||||
- Maps 'environment_state' to observation.environment_state
|
||||
- Maps 'agent_pos' to observation.state
|
||||
- Converts numpy arrays to tensors
|
||||
- Adds a batch dimension if missing
|
||||
"""
|
||||
|
||||
def _process_single_image(self, img: np.ndarray) -> Tensor:
|
||||
"""Process a single image array."""
|
||||
# Convert to tensor
|
||||
img_tensor = torch.from_numpy(img)
|
||||
|
||||
# Add batch dimension if needed
|
||||
if img_tensor.ndim == 3:
|
||||
img_tensor = img_tensor.unsqueeze(0)
|
||||
|
||||
# Validate image format
|
||||
_, h, w, c = img_tensor.shape
|
||||
if not (c < h and c < w):
|
||||
raise ValueError(f"Expected channel-last images, but got shape {img_tensor.shape}")
|
||||
|
||||
if img_tensor.dtype != torch.uint8:
|
||||
raise ValueError(f"Expected torch.uint8 images, but got {img_tensor.dtype}")
|
||||
|
||||
# Convert to channel-first format
|
||||
img_tensor = einops.rearrange(img_tensor, "b h w c -> b c h w").contiguous()
|
||||
|
||||
# Convert to float32 and normalize to [0, 1]
|
||||
img_tensor = img_tensor.type(torch.float32) / 255.0
|
||||
|
||||
return img_tensor
|
||||
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
Processes both image and state observations.
|
||||
"""
|
||||
|
||||
processed_obs = observation.copy()
|
||||
|
||||
if "pixels" in processed_obs:
|
||||
pixels = processed_obs.pop("pixels")
|
||||
|
||||
if isinstance(pixels, dict):
|
||||
imgs = {f"{OBS_IMAGES}.{key}": img for key, img in pixels.items()}
|
||||
else:
|
||||
imgs = {OBS_IMAGE: pixels}
|
||||
|
||||
for imgkey, img in imgs.items():
|
||||
processed_obs[imgkey] = self._process_single_image(img)
|
||||
|
||||
if "environment_state" in processed_obs:
|
||||
env_state_np = processed_obs.pop("environment_state")
|
||||
env_state = torch.from_numpy(env_state_np).float()
|
||||
if env_state.dim() == 1:
|
||||
env_state = env_state.unsqueeze(0)
|
||||
processed_obs[OBS_ENV_STATE] = env_state
|
||||
|
||||
if "agent_pos" in processed_obs:
|
||||
agent_pos_np = processed_obs.pop("agent_pos")
|
||||
agent_pos = torch.from_numpy(agent_pos_np).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
processed_obs[OBS_STATE] = agent_pos
|
||||
|
||||
return processed_obs
|
||||
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms feature keys to a standardized contract.
|
||||
|
||||
This method handles several renaming patterns:
|
||||
- Exact matches (e.g., 'pixels' -> 'OBS_IMAGE').
|
||||
- Prefixed exact matches (e.g., 'observation.pixels' -> 'OBS_IMAGE').
|
||||
- Prefix matches (e.g., 'pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- Prefixed prefix matches (e.g., 'observation.pixels.cam1' -> 'OBS_IMAGES.cam1').
|
||||
- environment_state -> OBS_ENV_STATE,
|
||||
- agent_pos -> OBS_STATE,
|
||||
- observation.environment_state -> OBS_ENV_STATE,
|
||||
- observation.agent_pos -> OBS_STATE
|
||||
"""
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
"environment_state": OBS_ENV_STATE,
|
||||
"agent_pos": OBS_STATE,
|
||||
}
|
||||
|
||||
prefix_pairs = {
|
||||
"pixels.": f"{OBS_IMAGES}.",
|
||||
}
|
||||
|
||||
for key in list(features.keys()):
|
||||
matched_prefix = False
|
||||
for old_prefix, new_prefix in prefix_pairs.items():
|
||||
prefixed_old = f"observation.{old_prefix}"
|
||||
if key.startswith(prefixed_old):
|
||||
suffix = key[len(prefixed_old) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if key.startswith(old_prefix):
|
||||
suffix = key[len(old_prefix) :]
|
||||
features[f"{new_prefix}{suffix}"] = features.pop(key)
|
||||
matched_prefix = True
|
||||
break
|
||||
|
||||
if matched_prefix:
|
||||
continue
|
||||
|
||||
for old, new in exact_pairs.items():
|
||||
if key == old or key == f"observation.{old}":
|
||||
if key in features:
|
||||
features[new] = features.pop(key)
|
||||
break
|
||||
|
||||
return features
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,51 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from lerobot.configs.types import PolicyFeature
|
||||
from lerobot.processor.pipeline import (
|
||||
ObservationProcessor,
|
||||
ProcessorStepRegistry,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="rename_processor")
|
||||
class RenameProcessor(ObservationProcessor):
|
||||
"""Rename processor that renames keys in the observation."""
|
||||
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def observation(self, observation):
|
||||
processed_obs = {}
|
||||
for key, value in observation.items():
|
||||
if key in self.rename_map:
|
||||
processed_obs[self.rename_map[key]] = value
|
||||
else:
|
||||
processed_obs[key] = value
|
||||
|
||||
return processed_obs
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"rename_map": self.rename_map}
|
||||
|
||||
def feature_contract(self, features: dict[str, PolicyFeature]) -> dict[str, PolicyFeature]:
|
||||
"""Transforms:
|
||||
- Each key in the observation that appears in `rename_map` is renamed to its value.
|
||||
- Keys not in `rename_map` remain unchanged.
|
||||
"""
|
||||
return {self.rename_map.get(k, k): v for k, v in features.items()}
|
||||
@@ -393,5 +393,9 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
return dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
record()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -112,5 +112,9 @@ def replay(cfg: ReplayConfig):
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
replay()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -501,6 +501,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info("End of eval")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
init_logging()
|
||||
eval_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -286,6 +286,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
policy.push_model_to_hub(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
init_logging()
|
||||
train()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -80,5 +80,9 @@ def setup_motors(cfg: SetupConfig):
|
||||
device.setup_motors()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
setup_motors()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -153,5 +153,9 @@ def teleoperate(cfg: TeleoperateConfig):
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
def main():
|
||||
teleoperate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,6 +19,7 @@ import traceback
|
||||
import pytest
|
||||
from serial import SerialException
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from tests.utils import DEVICE
|
||||
|
||||
# Import fixture modules as plugins
|
||||
@@ -69,3 +70,19 @@ def patch_builtins_input(monkeypatch):
|
||||
print(text)
|
||||
|
||||
monkeypatch.setattr("builtins.input", print_text)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def policy_feature_factory():
|
||||
"""PolicyFeature factory"""
|
||||
|
||||
def _pf(ft: FeatureType, shape: tuple[int, ...]) -> PolicyFeature:
|
||||
return PolicyFeature(type=ft, shape=shape)
|
||||
|
||||
return _pf
|
||||
|
||||
|
||||
def assert_contract_is_typed(features: dict[str, PolicyFeature]) -> None:
|
||||
assert isinstance(features, dict)
|
||||
assert all(isinstance(k, str) for k in features.keys())
|
||||
assert all(isinstance(v, PolicyFeature) for v in features.values())
|
||||
|
||||
@@ -27,11 +27,13 @@ from lerobot import available_policies
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.utils import cycle, dataset_to_policy_features
|
||||
from lerobot.envs.factory import make_env, make_env_config
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.act.configuration_act import ACTConfig
|
||||
from lerobot.policies.act.modeling_act import ACTTemporalEnsembler
|
||||
from lerobot.policies.factory import (
|
||||
get_policy_class,
|
||||
@@ -363,6 +365,54 @@ def test_normalize(insert_temporal_dim):
|
||||
unnormalize(output_batch)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
Asserts that multiple keys with type State/Action are correctly processed by the policy constructor,
|
||||
preventing erroneous creation of the policy object.
|
||||
"""
|
||||
input_features = {
|
||||
"observation.state": PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(10,),
|
||||
),
|
||||
}
|
||||
output_features = {
|
||||
"action": PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
),
|
||||
}
|
||||
|
||||
if multikey:
|
||||
"""Simulates the complete state/action is constructed from more granular multiple
|
||||
keys, of the same type as the overall state/action"""
|
||||
input_features = {}
|
||||
input_features["observation.state.subset1"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||
input_features["observation.state.subset2"] = PolicyFeature(type=FeatureType.STATE, shape=(5,))
|
||||
input_features["observation.state"] = PolicyFeature(type=FeatureType.STATE, shape=(10,))
|
||||
|
||||
output_features = {}
|
||||
output_features["action.first_three_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(3,))
|
||||
output_features["action.last_two_motors"] = PolicyFeature(type=FeatureType.ACTION, shape=(2,))
|
||||
output_features["action"] = PolicyFeature(
|
||||
type=FeatureType.ACTION,
|
||||
shape=(5,),
|
||||
)
|
||||
|
||||
config = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
|
||||
state_condition = config.robot_state_feature == input_features[OBS_STATE]
|
||||
action_condition = config.action_feature == output_features[ACTION]
|
||||
|
||||
assert state_condition, (
|
||||
f"Discrepancy detected. Robot state feature is {config.robot_state_feature} but policy expects {input_features[OBS_STATE]}"
|
||||
)
|
||||
assert action_condition, (
|
||||
f"Discrepancy detected. Action feature is {config.action_feature} but policy expects {output_features[ACTION]}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"ds_repo_id, policy_name, policy_kwargs, file_name_extra",
|
||||
[
|
||||
|
||||
@@ -0,0 +1,282 @@
|
||||
import torch
|
||||
|
||||
from lerobot.processor.pipeline import (
|
||||
RobotProcessor,
|
||||
TransitionKey,
|
||||
_default_batch_to_transition,
|
||||
_default_transition_to_batch,
|
||||
)
|
||||
|
||||
|
||||
def _dummy_batch():
|
||||
"""Create a dummy batch using the new format with observation.* and next.* keys."""
|
||||
return {
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.right": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": torch.tensor([[0.1, 0.2, 0.3, 0.4]]),
|
||||
"action": torch.tensor([[0.5]]),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
"info": {"key": "value"},
|
||||
}
|
||||
|
||||
|
||||
def test_observation_grouping_roundtrip():
|
||||
"""Test that observation.* keys are properly grouped and ungrouped."""
|
||||
proc = RobotProcessor([])
|
||||
batch_in = _dummy_batch()
|
||||
batch_out = proc(batch_in)
|
||||
|
||||
# Check that all observation.* keys are preserved
|
||||
original_obs_keys = {k: v for k, v in batch_in.items() if k.startswith("observation.")}
|
||||
reconstructed_obs_keys = {k: v for k, v in batch_out.items() if k.startswith("observation.")}
|
||||
|
||||
assert set(original_obs_keys.keys()) == set(reconstructed_obs_keys.keys())
|
||||
|
||||
# Check tensor values
|
||||
assert torch.allclose(batch_out["observation.image.left"], batch_in["observation.image.left"])
|
||||
assert torch.allclose(batch_out["observation.image.right"], batch_in["observation.image.right"])
|
||||
assert torch.allclose(batch_out["observation.state"], batch_in["observation.state"])
|
||||
|
||||
# Check other fields
|
||||
assert torch.allclose(batch_out["action"], batch_in["action"])
|
||||
assert batch_out["next.reward"] == batch_in["next.reward"]
|
||||
assert batch_out["next.done"] == batch_in["next.done"]
|
||||
assert batch_out["next.truncated"] == batch_in["next.truncated"]
|
||||
assert batch_out["info"] == batch_in["info"]
|
||||
|
||||
|
||||
def test_batch_to_transition_observation_grouping():
|
||||
"""Test that _default_batch_to_transition correctly groups observation.* keys."""
|
||||
batch = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
"action": "action_data",
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
"info": {"episode": 42},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Check observation is a dict with all observation.* keys
|
||||
assert isinstance(transition[TransitionKey.OBSERVATION], dict)
|
||||
assert "observation.image.top" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.image.left" in transition[TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.top"], batch["observation.image.top"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
transition[TransitionKey.OBSERVATION]["observation.image.left"], batch["observation.image.left"]
|
||||
)
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {"episode": 42}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
|
||||
def test_transition_to_batch_observation_flattening():
|
||||
"""Test that _default_transition_to_batch correctly flattens observation dict."""
|
||||
observation_dict = {
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
}
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: observation_dict,
|
||||
TransitionKey.ACTION: "action_data",
|
||||
TransitionKey.REWARD: 1.5,
|
||||
TransitionKey.DONE: True,
|
||||
TransitionKey.TRUNCATED: False,
|
||||
TransitionKey.INFO: {"episode": 42},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
batch = _default_transition_to_batch(transition)
|
||||
|
||||
# Check that observation.* keys are flattened back to batch
|
||||
assert "observation.image.top" in batch
|
||||
assert "observation.image.left" in batch
|
||||
assert "observation.state" in batch
|
||||
|
||||
# Check values are preserved
|
||||
assert torch.allclose(batch["observation.image.top"], observation_dict["observation.image.top"])
|
||||
assert torch.allclose(batch["observation.image.left"], observation_dict["observation.image.left"])
|
||||
assert batch["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields are mapped to next.* format
|
||||
assert batch["action"] == "action_data"
|
||||
assert batch["next.reward"] == 1.5
|
||||
assert batch["next.done"]
|
||||
assert not batch["next.truncated"]
|
||||
assert batch["info"] == {"episode": 42}
|
||||
|
||||
|
||||
def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
"action": "action_data",
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
"info": {"test": "no_obs"},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Observation should be None when no observation.* keys
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {"test": "no_obs"}
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
assert reconstructed_batch["info"] == {"test": "no_obs"}
|
||||
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
assert reconstructed_batch["info"] == {}
|
||||
|
||||
|
||||
def test_empty_batch():
|
||||
"""Test behavior with empty batch."""
|
||||
batch = {}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
|
||||
# All fields should have defaults
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
assert transition[TransitionKey.ACTION] is None
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
assert transition[TransitionKey.INFO] == {}
|
||||
assert transition[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
|
||||
# Round trip
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] is None
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
assert reconstructed_batch["info"] == {}
|
||||
|
||||
|
||||
def test_complex_nested_observation():
|
||||
"""Test with complex nested observation data."""
|
||||
batch = {
|
||||
"observation.image.top": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567890},
|
||||
"observation.image.left": {"image": torch.randn(1, 3, 128, 128), "timestamp": 1234567891},
|
||||
"observation.state": torch.randn(7),
|
||||
"action": torch.randn(8),
|
||||
"next.reward": 3.14,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
"info": {"episode_length": 200, "success": True},
|
||||
}
|
||||
|
||||
transition = _default_batch_to_transition(batch)
|
||||
reconstructed_batch = _default_transition_to_batch(transition)
|
||||
|
||||
# Check that all observation keys are preserved
|
||||
original_obs_keys = {k for k in batch if k.startswith("observation.")}
|
||||
reconstructed_obs_keys = {k for k in reconstructed_batch if k.startswith("observation.")}
|
||||
|
||||
assert original_obs_keys == reconstructed_obs_keys
|
||||
|
||||
# Check tensor values
|
||||
assert torch.allclose(batch["observation.state"], reconstructed_batch["observation.state"])
|
||||
|
||||
# Check nested dict with tensors
|
||||
assert torch.allclose(
|
||||
batch["observation.image.top"]["image"], reconstructed_batch["observation.image.top"]["image"]
|
||||
)
|
||||
assert torch.allclose(
|
||||
batch["observation.image.left"]["image"], reconstructed_batch["observation.image.left"]["image"]
|
||||
)
|
||||
|
||||
# Check action tensor
|
||||
assert torch.allclose(batch["action"], reconstructed_batch["action"])
|
||||
|
||||
# Check other fields
|
||||
assert batch["next.reward"] == reconstructed_batch["next.reward"]
|
||||
assert batch["next.done"] == reconstructed_batch["next.done"]
|
||||
assert batch["next.truncated"] == reconstructed_batch["next.truncated"]
|
||||
assert batch["info"] == reconstructed_batch["info"]
|
||||
|
||||
|
||||
def test_custom_converter():
|
||||
"""Test that custom converters can still be used."""
|
||||
|
||||
def to_tr(batch):
|
||||
# Custom converter that modifies the reward
|
||||
tr = _default_batch_to_transition(batch)
|
||||
# Double the reward
|
||||
reward = tr.get(TransitionKey.REWARD, 0.0)
|
||||
new_tr = tr.copy()
|
||||
new_tr[TransitionKey.REWARD] = reward * 2 if reward is not None else 0.0
|
||||
return new_tr
|
||||
|
||||
def to_batch(tr):
|
||||
batch = _default_transition_to_batch(tr)
|
||||
return batch
|
||||
|
||||
processor = RobotProcessor(steps=[], to_transition=to_tr, to_output=to_batch)
|
||||
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 4),
|
||||
"action": torch.randn(1, 2),
|
||||
"next.reward": 1.0,
|
||||
"next.done": False,
|
||||
}
|
||||
|
||||
result = processor(batch)
|
||||
|
||||
# Check the reward was doubled by our custom converter
|
||||
assert result["next.reward"] == 2.0
|
||||
assert torch.allclose(result["observation.state"], batch["observation.state"])
|
||||
assert torch.allclose(result["action"], batch["action"])
|
||||
@@ -0,0 +1,628 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from unittest.mock import Mock
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.normalize_processor import (
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
)
|
||||
from lerobot.processor.pipeline import RobotProcessor, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_numpy_conversion():
|
||||
stats = {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
|
||||
assert torch.allclose(tensor_stats["observation.image"]["mean"], torch.tensor([0.5, 0.5, 0.5]))
|
||||
assert torch.allclose(tensor_stats["observation.image"]["std"], torch.tensor([0.2, 0.2, 0.2]))
|
||||
|
||||
|
||||
def test_tensor_conversion():
|
||||
stats = {
|
||||
"action": {
|
||||
"mean": torch.tensor([0.0, 0.0]),
|
||||
"std": torch.tensor([1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert tensor_stats["action"]["mean"].dtype == torch.float32
|
||||
assert tensor_stats["action"]["std"].dtype == torch.float32
|
||||
|
||||
|
||||
def test_scalar_conversion():
|
||||
stats = {
|
||||
"reward": {
|
||||
"mean": 0.5,
|
||||
"std": 0.1,
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
|
||||
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
|
||||
|
||||
|
||||
def test_list_conversion():
|
||||
stats = {
|
||||
"observation.state": {
|
||||
"min": [0.0, -1.0, -2.0],
|
||||
"max": [1.0, 1.0, 2.0],
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
|
||||
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
|
||||
|
||||
|
||||
def test_unsupported_type():
|
||||
stats = {
|
||||
"bad_key": {
|
||||
"mean": "string_value",
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError, match="Unsupported type"):
|
||||
_convert_stats_to_tensors(stats)
|
||||
|
||||
|
||||
# Helper functions to create feature maps and norm maps
|
||||
def _create_observation_features():
|
||||
return {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
}
|
||||
|
||||
|
||||
def _create_observation_norm_map():
|
||||
return {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
|
||||
# Fixtures for observation normalisation tests using NormalizerProcessor
|
||||
@pytest.fixture
|
||||
def observation_stats():
|
||||
return {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"observation.state": {
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def observation_normalizer(observation_stats):
|
||||
"""Return a NormalizerProcessor that only has observation stats (no action)."""
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
return NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
||||
|
||||
|
||||
def test_mean_std_normalization(observation_normalizer):
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = observation_normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check mean/std normalization
|
||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||
assert torch.allclose(normalized_obs["observation.image"], expected_image)
|
||||
|
||||
|
||||
def test_min_max_normalization(observation_normalizer):
|
||||
observation = {
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = observation_normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check min/max normalization to [-1, 1]
|
||||
# For state[0]: 2 * (0.5 - 0.0) / (1.0 - 0.0) - 1 = 0.0
|
||||
# For state[1]: 2 * (0.0 - (-1.0)) / (1.0 - (-1.0)) - 1 = 0.0
|
||||
expected_state = torch.tensor([0.0, 0.0])
|
||||
assert torch.allclose(normalized_obs["observation.state"], expected_state, atol=1e-6)
|
||||
|
||||
|
||||
def test_selective_normalization(observation_stats):
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
normalizer = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=observation_stats, normalize_keys={"observation.image"}
|
||||
)
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
# Only image should be normalized
|
||||
assert torch.allclose(normalized_obs["observation.image"], (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2)
|
||||
# State should remain unchanged
|
||||
assert torch.allclose(normalized_obs["observation.state"], observation["observation.state"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_device_compatibility(observation_stats):
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=observation_stats)
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]).cuda(),
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
normalized_transition = normalizer(transition)
|
||||
normalized_obs = normalized_transition[TransitionKey.OBSERVATION]
|
||||
|
||||
assert normalized_obs["observation.image"].device.type == "cuda"
|
||||
|
||||
|
||||
def test_from_lerobot_dataset():
|
||||
# Mock dataset
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {
|
||||
"observation.image": {"mean": [0.5], "std": [0.2]},
|
||||
"action": {"mean": [0.0], "std": [1.0]},
|
||||
}
|
||||
|
||||
features = {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (1,)),
|
||||
}
|
||||
norm_map = {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
normalizer = NormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
|
||||
# Both observation and action statistics should be present in tensor stats
|
||||
assert "observation.image" in normalizer._tensor_stats
|
||||
assert "action" in normalizer._tensor_stats
|
||||
|
||||
|
||||
def test_state_dict_save_load(observation_normalizer):
|
||||
# Save state
|
||||
state_dict = observation_normalizer.state_dict()
|
||||
|
||||
# Create new normalizer and load state
|
||||
features = _create_observation_features()
|
||||
norm_map = _create_observation_norm_map()
|
||||
new_normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
||||
new_normalizer.load_state_dict(state_dict)
|
||||
|
||||
# Test that it works the same
|
||||
observation = {"observation.image": torch.tensor([0.7, 0.5, 0.3])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result1 = observation_normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
result2 = new_normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
assert torch.allclose(result1["observation.image"], result2["observation.image"])
|
||||
|
||||
|
||||
# Fixtures for ActionUnnormalizer tests
|
||||
@pytest.fixture
|
||||
def action_stats_mean_std():
|
||||
return {
|
||||
"mean": np.array([0.0, 0.0, 0.0]),
|
||||
"std": np.array([1.0, 2.0, 0.5]),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def action_stats_min_max():
|
||||
return {
|
||||
"min": np.array([-1.0, -2.0, 0.0]),
|
||||
"max": np.array([1.0, 2.0, 1.0]),
|
||||
}
|
||||
|
||||
|
||||
def _create_action_features():
|
||||
return {
|
||||
"action": PolicyFeature(FeatureType.ACTION, (3,)),
|
||||
}
|
||||
|
||||
|
||||
def _create_action_norm_map_mean_std():
|
||||
return {
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
|
||||
def _create_action_norm_map_min_max():
|
||||
return {
|
||||
FeatureType.ACTION: NormalizationMode.MIN_MAX,
|
||||
}
|
||||
|
||||
|
||||
def test_mean_std_unnormalization(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
normalized_action = torch.tensor([1.0, -0.5, 2.0])
|
||||
transition = create_transition(action=normalized_action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
||||
|
||||
# action * std + mean
|
||||
expected = torch.tensor([1.0 * 1.0 + 0.0, -0.5 * 2.0 + 0.0, 2.0 * 0.5 + 0.0])
|
||||
assert torch.allclose(unnormalized_action, expected)
|
||||
|
||||
|
||||
def test_min_max_unnormalization(action_stats_min_max):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_min_max()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_min_max}
|
||||
)
|
||||
|
||||
# Actions in [-1, 1]
|
||||
normalized_action = torch.tensor([0.0, -1.0, 1.0])
|
||||
transition = create_transition(action=normalized_action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
||||
|
||||
# Map from [-1, 1] to [min, max]
|
||||
# (action + 1) / 2 * (max - min) + min
|
||||
expected = torch.tensor(
|
||||
[
|
||||
(0.0 + 1) / 2 * (1.0 - (-1.0)) + (-1.0), # 0.0
|
||||
(-1.0 + 1) / 2 * (2.0 - (-2.0)) + (-2.0), # -2.0
|
||||
(1.0 + 1) / 2 * (1.0 - 0.0) + 0.0, # 1.0
|
||||
]
|
||||
)
|
||||
assert torch.allclose(unnormalized_action, expected)
|
||||
|
||||
|
||||
def test_numpy_action_input(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
||||
transition = create_transition(action=normalized_action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
unnormalized_action = unnormalized_transition[TransitionKey.ACTION]
|
||||
|
||||
assert isinstance(unnormalized_action, torch.Tensor)
|
||||
expected = torch.tensor([1.0, -1.0, 1.0])
|
||||
assert torch.allclose(unnormalized_action, expected)
|
||||
|
||||
|
||||
def test_none_action(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
transition = create_transition()
|
||||
result = unnormalizer(transition)
|
||||
|
||||
# Should return transition unchanged
|
||||
assert result == transition
|
||||
|
||||
|
||||
def test_action_from_lerobot_dataset():
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {"action": {"mean": [0.0], "std": [1.0]}}
|
||||
features = {"action": PolicyFeature(FeatureType.ACTION, (1,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
unnormalizer = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
assert "mean" in unnormalizer._tensor_stats["action"]
|
||||
|
||||
|
||||
# Fixtures for NormalizerProcessor tests
|
||||
@pytest.fixture
|
||||
def full_stats():
|
||||
return {
|
||||
"observation.image": {
|
||||
"mean": np.array([0.5, 0.5, 0.5]),
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
},
|
||||
"observation.state": {
|
||||
"min": np.array([0.0, -1.0]),
|
||||
"max": np.array([1.0, 1.0]),
|
||||
},
|
||||
"action": {
|
||||
"mean": np.array([0.0, 0.0]),
|
||||
"std": np.array([1.0, 2.0]),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _create_full_features():
|
||||
return {
|
||||
"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
|
||||
"observation.state": PolicyFeature(FeatureType.STATE, (2,)),
|
||||
"action": PolicyFeature(FeatureType.ACTION, (2,)),
|
||||
}
|
||||
|
||||
|
||||
def _create_full_norm_map():
|
||||
return {
|
||||
FeatureType.VISUAL: NormalizationMode.MEAN_STD,
|
||||
FeatureType.STATE: NormalizationMode.MIN_MAX,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def normalizer_processor(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
return NormalizerProcessor(features=features, norm_map=norm_map, stats=full_stats)
|
||||
|
||||
|
||||
def test_combined_normalization(normalizer_processor):
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
reward=1.0,
|
||||
done=False,
|
||||
truncated=False,
|
||||
info={},
|
||||
complementary_data={},
|
||||
)
|
||||
|
||||
processed_transition = normalizer_processor(transition)
|
||||
|
||||
# Check normalized observations
|
||||
processed_obs = processed_transition[TransitionKey.OBSERVATION]
|
||||
expected_image = (torch.tensor([0.7, 0.5, 0.3]) - 0.5) / 0.2
|
||||
assert torch.allclose(processed_obs["observation.image"], expected_image)
|
||||
|
||||
# Check normalized action
|
||||
processed_action = processed_transition[TransitionKey.ACTION]
|
||||
expected_action = torch.tensor([(1.0 - 0.0) / 1.0, (-0.5 - 0.0) / 2.0])
|
||||
assert torch.allclose(processed_action, expected_action)
|
||||
|
||||
# Check other fields remain unchanged
|
||||
assert processed_transition[TransitionKey.REWARD] == 1.0
|
||||
assert not processed_transition[TransitionKey.DONE]
|
||||
|
||||
|
||||
def test_processor_from_lerobot_dataset(full_stats):
|
||||
# Mock dataset
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = full_stats
|
||||
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
|
||||
processor = NormalizerProcessor.from_lerobot_dataset(
|
||||
mock_dataset, features, norm_map, normalize_keys={"observation.image"}
|
||||
)
|
||||
|
||||
assert processor.normalize_keys == {"observation.image"}
|
||||
assert "observation.image" in processor._tensor_stats
|
||||
assert "action" in processor._tensor_stats
|
||||
|
||||
|
||||
def test_get_config(full_stats):
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
processor = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||
)
|
||||
|
||||
config = processor.get_config()
|
||||
expected_config = {
|
||||
"normalize_keys": ["observation.image"],
|
||||
"eps": 1e-6,
|
||||
"features": {
|
||||
"observation.image": {"type": "VISUAL", "shape": (3, 96, 96)},
|
||||
"observation.state": {"type": "STATE", "shape": (2,)},
|
||||
"action": {"type": "ACTION", "shape": (2,)},
|
||||
},
|
||||
"norm_map": {
|
||||
"VISUAL": "MEAN_STD",
|
||||
"STATE": "MIN_MAX",
|
||||
"ACTION": "MEAN_STD",
|
||||
},
|
||||
}
|
||||
assert config == expected_config
|
||||
|
||||
|
||||
def test_integration_with_robot_processor(normalizer_processor):
|
||||
"""Test integration with RobotProcessor pipeline"""
|
||||
robot_processor = RobotProcessor([normalizer_processor])
|
||||
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
reward=1.0,
|
||||
done=False,
|
||||
truncated=False,
|
||||
info={},
|
||||
complementary_data={},
|
||||
)
|
||||
|
||||
processed_transition = robot_processor(transition)
|
||||
|
||||
# Verify the processing worked
|
||||
assert isinstance(processed_transition[TransitionKey.OBSERVATION], dict)
|
||||
assert isinstance(processed_transition[TransitionKey.ACTION], torch.Tensor)
|
||||
|
||||
|
||||
# Edge case tests
|
||||
def test_empty_observation():
|
||||
stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
|
||||
transition = create_transition()
|
||||
result = normalizer(transition)
|
||||
|
||||
assert result == transition
|
||||
|
||||
|
||||
def test_empty_stats():
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats={})
|
||||
observation = {"observation.image": torch.tensor([0.5])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = normalizer(transition)
|
||||
# Should return observation unchanged since no stats are available
|
||||
assert torch.allclose(
|
||||
result[TransitionKey.OBSERVATION]["observation.image"], observation["observation.image"]
|
||||
)
|
||||
|
||||
|
||||
def test_partial_stats():
|
||||
"""If statistics are incomplete, the value should pass through unchanged."""
|
||||
stats = {"observation.image": {"mean": [0.5]}} # Missing std / (min,max)
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
normalizer = NormalizerProcessor(features=features, norm_map=norm_map, stats=stats)
|
||||
observation = {"observation.image": torch.tensor([0.7])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
processed = normalizer(transition)[TransitionKey.OBSERVATION]
|
||||
assert torch.allclose(processed["observation.image"], observation["observation.image"])
|
||||
|
||||
|
||||
def test_missing_action_stats_no_error():
|
||||
mock_dataset = Mock()
|
||||
mock_dataset.meta.stats = {"observation.image": {"mean": [0.5], "std": [0.2]}}
|
||||
|
||||
features = {"observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96))}
|
||||
norm_map = {FeatureType.VISUAL: NormalizationMode.MEAN_STD}
|
||||
|
||||
processor = UnnormalizerProcessor.from_lerobot_dataset(mock_dataset, features, norm_map)
|
||||
# The tensor stats should not contain the 'action' key
|
||||
assert "action" not in processor._tensor_stats
|
||||
|
||||
|
||||
def test_serialization_roundtrip(full_stats):
|
||||
"""Test that features and norm_map can be serialized and deserialized correctly."""
|
||||
features = _create_full_features()
|
||||
norm_map = _create_full_norm_map()
|
||||
original_processor = NormalizerProcessor(
|
||||
features=features, norm_map=norm_map, stats=full_stats, normalize_keys={"observation.image"}, eps=1e-6
|
||||
)
|
||||
|
||||
# Get config (serialization)
|
||||
config = original_processor.get_config()
|
||||
|
||||
# Create a new processor from the config (deserialization)
|
||||
new_processor = NormalizerProcessor(
|
||||
features=config["features"],
|
||||
norm_map=config["norm_map"],
|
||||
stats=full_stats,
|
||||
normalize_keys=set(config["normalize_keys"]),
|
||||
eps=config["eps"],
|
||||
)
|
||||
|
||||
# Test that both processors work the same way
|
||||
observation = {
|
||||
"observation.image": torch.tensor([0.7, 0.5, 0.3]),
|
||||
"observation.state": torch.tensor([0.5, 0.0]),
|
||||
}
|
||||
action = torch.tensor([1.0, -0.5])
|
||||
transition = create_transition(
|
||||
observation=observation,
|
||||
action=action,
|
||||
reward=1.0,
|
||||
done=False,
|
||||
truncated=False,
|
||||
info={},
|
||||
complementary_data={},
|
||||
)
|
||||
|
||||
result1 = original_processor(transition)
|
||||
result2 = new_processor(transition)
|
||||
|
||||
# Compare results
|
||||
assert torch.allclose(
|
||||
result1[TransitionKey.OBSERVATION]["observation.image"],
|
||||
result2[TransitionKey.OBSERVATION]["observation.image"],
|
||||
)
|
||||
assert torch.allclose(result1[TransitionKey.ACTION], result2[TransitionKey.ACTION])
|
||||
|
||||
# Verify features and norm_map are correctly reconstructed
|
||||
assert new_processor.features.keys() == original_processor.features.keys()
|
||||
for key in new_processor.features:
|
||||
assert new_processor.features[key].type == original_processor.features[key].type
|
||||
assert new_processor.features[key].shape == original_processor.features[key].shape
|
||||
|
||||
assert new_processor.norm_map == original_processor.norm_map
|
||||
@@ -0,0 +1,486 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor import VanillaObservationProcessor
|
||||
from lerobot.processor.pipeline import TransitionKey
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_process_single_image():
|
||||
"""Test processing a single image."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create a mock image (H, W, C) format, uint8
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that the image was processed correctly
|
||||
assert "observation.image" in processed_obs
|
||||
processed_img = processed_obs["observation.image"]
|
||||
|
||||
# Check shape: should be (1, 3, 64, 64) - batch, channels, height, width
|
||||
assert processed_img.shape == (1, 3, 64, 64)
|
||||
|
||||
# Check dtype and range
|
||||
assert processed_img.dtype == torch.float32
|
||||
assert processed_img.min() >= 0.0
|
||||
assert processed_img.max() <= 1.0
|
||||
|
||||
|
||||
def test_process_image_dict():
|
||||
"""Test processing multiple images in a dictionary."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create mock images
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": {"camera1": image1, "camera2": image2}}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both images were processed
|
||||
assert "observation.images.camera1" in processed_obs
|
||||
assert "observation.images.camera2" in processed_obs
|
||||
|
||||
# Check shapes
|
||||
assert processed_obs["observation.images.camera1"].shape == (1, 3, 32, 32)
|
||||
assert processed_obs["observation.images.camera2"].shape == (1, 3, 48, 48)
|
||||
|
||||
|
||||
def test_process_batched_image():
|
||||
"""Test processing already batched images."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create a batched image (B, H, W, C)
|
||||
image = np.random.randint(0, 256, size=(2, 64, 64, 3), dtype=np.uint8)
|
||||
|
||||
observation = {"pixels": image}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimension is preserved
|
||||
assert processed_obs["observation.image"].shape == (2, 3, 64, 64)
|
||||
|
||||
|
||||
def test_invalid_image_format():
|
||||
"""Test error handling for invalid image formats."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test wrong channel order (channels first)
|
||||
image = np.random.randint(0, 256, size=(3, 64, 64), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected channel-last images"):
|
||||
processor(transition)
|
||||
|
||||
|
||||
def test_invalid_image_dtype():
|
||||
"""Test error handling for invalid image dtype."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test wrong dtype
|
||||
image = np.random.rand(64, 64, 3).astype(np.float32)
|
||||
observation = {"pixels": image}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
with pytest.raises(ValueError, match="Expected torch.uint8 images"):
|
||||
processor(transition)
|
||||
|
||||
|
||||
def test_no_pixels_in_observation():
|
||||
"""Test processor when no pixels are in observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should preserve other data unchanged
|
||||
assert "other_data" in processed_obs
|
||||
np.testing.assert_array_equal(processed_obs["other_data"], np.array([1, 2, 3]))
|
||||
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
|
||||
assert result == transition
|
||||
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test serialization methods."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert isinstance(config, dict)
|
||||
|
||||
# Test state_dict
|
||||
state = processor.state_dict()
|
||||
assert isinstance(state, dict)
|
||||
|
||||
# Test load_state_dict (should not raise)
|
||||
processor.load_state_dict(state)
|
||||
|
||||
# Test reset (should not raise)
|
||||
processor.reset()
|
||||
|
||||
|
||||
def test_process_environment_state():
|
||||
"""Test processing environment_state."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
observation = {"environment_state": env_state}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that environment_state was renamed and processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.environment_state"]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[1.0, 2.0, 3.0]]))
|
||||
|
||||
|
||||
def test_process_agent_pos():
|
||||
"""Test processing agent_pos."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that agent_pos was renamed and processed
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
processed_state = processed_obs["observation.state"]
|
||||
assert processed_state.shape == (1, 3) # Batch dimension added
|
||||
assert processed_state.dtype == torch.float32
|
||||
torch.testing.assert_close(processed_state, torch.tensor([[0.5, -0.5, 1.0]]))
|
||||
|
||||
|
||||
def test_process_batched_states():
|
||||
"""Test processing already batched states."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
agent_pos = np.array([[0.5, -0.5], [1.0, -1.0]], dtype=np.float32)
|
||||
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that batch dimensions are preserved
|
||||
assert processed_obs["observation.environment_state"].shape == (2, 2)
|
||||
assert processed_obs["observation.state"].shape == (2, 2)
|
||||
|
||||
|
||||
def test_process_both_states():
|
||||
"""Test processing both environment_state and agent_pos."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
env_state = np.array([1.0, 2.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5], dtype=np.float32)
|
||||
|
||||
observation = {"environment_state": env_state, "agent_pos": agent_pos, "other_data": "keep_me"}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that both states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
|
||||
# Check that original keys were removed
|
||||
assert "environment_state" not in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
# Check that other data was preserved
|
||||
assert processed_obs["other_data"] == "keep_me"
|
||||
|
||||
|
||||
def test_no_states_in_observation():
|
||||
"""Test processor when no states are in observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {"other_data": np.array([1, 2, 3])}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Should preserve data unchanged
|
||||
np.testing.assert_array_equal(processed_obs, observation)
|
||||
|
||||
|
||||
def test_complete_observation_processing():
|
||||
"""Test processing a complete observation with both images and states."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create mock data
|
||||
image = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
|
||||
observation = {
|
||||
"pixels": image,
|
||||
"environment_state": env_state,
|
||||
"agent_pos": agent_pos,
|
||||
"other_data": "preserve_me",
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that image was processed
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.image"].shape == (1, 3, 32, 32)
|
||||
|
||||
# Check that states were processed
|
||||
assert "observation.environment_state" in processed_obs
|
||||
assert "observation.state" in processed_obs
|
||||
|
||||
# Check that original keys were removed
|
||||
assert "pixels" not in processed_obs
|
||||
assert "environment_state" not in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
# Check that other data was preserved
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
|
||||
|
||||
def test_image_only_processing():
|
||||
"""Test processing observation with only images."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
observation = {"pixels": image}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.image" in processed_obs
|
||||
assert len(processed_obs) == 1
|
||||
|
||||
|
||||
def test_state_only_processing():
|
||||
"""Test processing observation with only states."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
observation = {"agent_pos": agent_pos}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test processing empty observation."""
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
observation = {}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert processed_obs == {}
|
||||
|
||||
|
||||
def test_equivalent_to_original_function():
|
||||
"""Test that ObservationProcessor produces equivalent results to preprocess_observation."""
|
||||
# Import the original function for comparison
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create test data similar to what the original function expects
|
||||
image = np.random.randint(0, 256, size=(64, 64, 3), dtype=np.uint8)
|
||||
env_state = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
agent_pos = np.array([0.5, -0.5, 1.0], dtype=np.float32)
|
||||
|
||||
observation = {"pixels": image, "environment_state": env_state, "agent_pos": agent_pos}
|
||||
|
||||
# Process with original function
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = create_transition(observation=observation)
|
||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
|
||||
for key in original_result:
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
|
||||
|
||||
def test_equivalent_with_image_dict():
|
||||
"""Test equivalence with dictionary of images."""
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
|
||||
processor = VanillaObservationProcessor()
|
||||
|
||||
# Create test data with multiple cameras
|
||||
image1 = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
|
||||
image2 = np.random.randint(0, 256, size=(48, 48, 3), dtype=np.uint8)
|
||||
agent_pos = np.array([1.0, 2.0], dtype=np.float32)
|
||||
|
||||
observation = {"pixels": {"cam1": image1, "cam2": image2}, "agent_pos": agent_pos}
|
||||
|
||||
# Process with original function
|
||||
original_result = preprocess_observation(observation)
|
||||
|
||||
# Process with new processor
|
||||
transition = create_transition(observation=observation)
|
||||
processor_result = processor(transition)[TransitionKey.OBSERVATION]
|
||||
|
||||
# Compare results
|
||||
assert set(original_result.keys()) == set(processor_result.keys())
|
||||
|
||||
for key in original_result:
|
||||
torch.testing.assert_close(original_result[key], processor_result[key])
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["pixels"]
|
||||
assert "pixels" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_observation_pixels_to_image(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.pixels": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
|
||||
assert OBS_IMAGE in out and out[OBS_IMAGE] == features["observation.pixels"]
|
||||
assert "observation.pixels" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_image_processor_feature_contract_multi_camera_and_prefixed(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"pixels.front": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"pixels.wrist": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"observation.pixels.rear": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (7,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
|
||||
assert f"{OBS_IMAGES}.front" in out and out[f"{OBS_IMAGES}.front"] == features["pixels.front"]
|
||||
assert f"{OBS_IMAGES}.wrist" in out and out[f"{OBS_IMAGES}.wrist"] == features["pixels.wrist"]
|
||||
assert f"{OBS_IMAGES}.rear" in out and out[f"{OBS_IMAGES}.rear"] == features["observation.pixels.rear"]
|
||||
assert "pixels.front" not in out and "pixels.wrist" not in out and "observation.pixels.rear" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_environment_and_agent_pos(policy_feature_factory):
|
||||
processor = VanillaObservationProcessor()
|
||||
features = {
|
||||
"environment_state": policy_feature_factory(FeatureType.STATE, (3,)),
|
||||
"agent_pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"keep": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
out = processor.feature_contract(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["agent_pos"]
|
||||
assert "environment_state" not in out and "agent_pos" not in out
|
||||
assert out["keep"] == features["keep"]
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_state_processor_feature_contract_prefixed_inputs(policy_feature_factory):
|
||||
proc = VanillaObservationProcessor()
|
||||
features = {
|
||||
"observation.environment_state": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"observation.agent_pos": policy_feature_factory(FeatureType.STATE, (4,)),
|
||||
}
|
||||
out = proc.feature_contract(features.copy())
|
||||
|
||||
assert OBS_ENV_STATE in out and out[OBS_ENV_STATE] == features["observation.environment_state"]
|
||||
assert OBS_STATE in out and out[OBS_STATE] == features["observation.agent_pos"]
|
||||
assert "environment_state" not in out and "agent_pos" not in out
|
||||
assert_contract_is_typed(out)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,467 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_basic_renaming():
|
||||
"""Test basic key renaming functionality."""
|
||||
rename_map = {
|
||||
"old_key1": "new_key1",
|
||||
"old_key2": "new_key2",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"old_key1": torch.tensor([1.0, 2.0]),
|
||||
"old_key2": np.array([3.0, 4.0]),
|
||||
"unchanged_key": "keep_me",
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renamed keys
|
||||
assert "new_key1" in processed_obs
|
||||
assert "new_key2" in processed_obs
|
||||
assert "old_key1" not in processed_obs
|
||||
assert "old_key2" not in processed_obs
|
||||
|
||||
# Check values are preserved
|
||||
torch.testing.assert_close(processed_obs["new_key1"], torch.tensor([1.0, 2.0]))
|
||||
np.testing.assert_array_equal(processed_obs["new_key2"], np.array([3.0, 4.0]))
|
||||
|
||||
# Check unchanged key is preserved
|
||||
assert processed_obs["unchanged_key"] == "keep_me"
|
||||
|
||||
|
||||
def test_empty_rename_map():
|
||||
"""Test processor with empty rename map (should pass through unchanged)."""
|
||||
processor = RenameProcessor(rename_map={})
|
||||
|
||||
observation = {
|
||||
"key1": torch.tensor([1.0]),
|
||||
"key2": "value2",
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# All keys should be unchanged
|
||||
assert processed_obs.keys() == observation.keys()
|
||||
torch.testing.assert_close(processed_obs["key1"], observation["key1"])
|
||||
assert processed_obs["key2"] == observation["key2"]
|
||||
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
|
||||
# Should return transition unchanged
|
||||
assert result == transition
|
||||
|
||||
|
||||
def test_overlapping_rename():
|
||||
"""Test renaming when new names might conflict."""
|
||||
rename_map = {
|
||||
"a": "b",
|
||||
"b": "c", # This creates a potential conflict
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"a": 1,
|
||||
"b": 2,
|
||||
"x": 3,
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that renaming happens correctly
|
||||
assert "a" not in processed_obs
|
||||
assert processed_obs["b"] == 1 # 'a' renamed to 'b'
|
||||
assert processed_obs["c"] == 2 # original 'b' renamed to 'c'
|
||||
assert processed_obs["x"] == 3
|
||||
|
||||
|
||||
def test_partial_rename():
|
||||
"""Test renaming only some keys."""
|
||||
rename_map = {
|
||||
"observation.state": "observation.proprio_state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.state": torch.randn(10),
|
||||
"pixels": np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8),
|
||||
"reward": 1.0,
|
||||
"info": {"episode": 1},
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renamed keys
|
||||
assert "observation.proprio_state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert "observation.state" not in processed_obs
|
||||
assert "pixels" not in processed_obs
|
||||
|
||||
# Check unchanged keys
|
||||
assert processed_obs["reward"] == 1.0
|
||||
assert processed_obs["info"] == {"episode": 1}
|
||||
|
||||
|
||||
def test_get_config():
|
||||
"""Test configuration serialization."""
|
||||
rename_map = {
|
||||
"old1": "new1",
|
||||
"old2": "new2",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
config = processor.get_config()
|
||||
assert config == {"rename_map": rename_map}
|
||||
|
||||
|
||||
def test_state_dict():
|
||||
"""Test state dict (should be empty for RenameProcessor)."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
|
||||
state = processor.state_dict()
|
||||
assert state == {}
|
||||
|
||||
# Load state dict should work even with empty dict
|
||||
processor.load_state_dict({})
|
||||
|
||||
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test integration with RobotProcessor pipeline."""
|
||||
rename_map = {
|
||||
"agent_pos": "observation.state",
|
||||
"pixels": "observation.image",
|
||||
}
|
||||
rename_processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
pipeline = RobotProcessor([rename_processor])
|
||||
|
||||
observation = {
|
||||
"agent_pos": np.array([1.0, 2.0, 3.0]),
|
||||
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
|
||||
"other_data": "preserve_me",
|
||||
}
|
||||
transition = create_transition(
|
||||
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
|
||||
)
|
||||
|
||||
result = pipeline(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renaming worked through pipeline
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert "agent_pos" not in processed_obs
|
||||
assert "pixels" not in processed_obs
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
|
||||
# Check other transition elements unchanged
|
||||
assert result[TransitionKey.REWARD] == 0.5
|
||||
assert result[TransitionKey.DONE] is False
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading processor with RobotProcessor."""
|
||||
rename_map = {
|
||||
"old_state": "observation.state",
|
||||
"old_image": "observation.image",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
pipeline = RobotProcessor([processor], name="TestRenameProcessor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Check files were created
|
||||
config_path = Path(tmp_dir) / "testrenameprocessor.json" # Based on name="TestRenameProcessor"
|
||||
assert config_path.exists()
|
||||
|
||||
# No state files should be created for RenameProcessor
|
||||
state_files = list(Path(tmp_dir).glob("*.safetensors"))
|
||||
assert len(state_files) == 0
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
|
||||
assert loaded_pipeline.name == "TestRenameProcessor"
|
||||
assert len(loaded_pipeline) == 1
|
||||
|
||||
# Check that loaded processor works correctly
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessor)
|
||||
assert loaded_processor.rename_map == rename_map
|
||||
|
||||
# Test functionality after loading
|
||||
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = loaded_pipeline(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
assert processed_obs["observation.state"] == [1, 2, 3]
|
||||
assert processed_obs["observation.image"] == "image_data"
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
"""Test that RenameProcessor is properly registered."""
|
||||
# Check that it's registered
|
||||
assert "rename_processor" in ProcessorStepRegistry.list()
|
||||
|
||||
# Get from registry
|
||||
retrieved_class = ProcessorStepRegistry.get("rename_processor")
|
||||
assert retrieved_class is RenameProcessor
|
||||
|
||||
# Create instance from registry
|
||||
instance = retrieved_class(rename_map={"old": "new"})
|
||||
assert isinstance(instance, RenameProcessor)
|
||||
assert instance.rename_map == {"old": "new"}
|
||||
|
||||
|
||||
def test_registry_based_save_load():
|
||||
"""Test save/load using registry name instead of module path."""
|
||||
processor = RenameProcessor(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = RobotProcessor([processor])
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save and load
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
|
||||
# Verify config uses registry name
|
||||
import json
|
||||
|
||||
with open(Path(tmp_dir) / "robotprocessor.json") as f: # Default name is "RobotProcessor"
|
||||
config = json.load(f)
|
||||
|
||||
assert "registry_name" in config["steps"][0]
|
||||
assert config["steps"][0]["registry_name"] == "rename_processor"
|
||||
assert "class" not in config["steps"][0] # Should use registry, not module path
|
||||
|
||||
# Load should work
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
loaded_processor = loaded_pipeline.steps[0]
|
||||
assert isinstance(loaded_processor, RenameProcessor)
|
||||
assert loaded_processor.rename_map == {"key1": "renamed_key1"}
|
||||
|
||||
|
||||
def test_chained_rename_processors():
|
||||
"""Test multiple RenameProcessors in a pipeline."""
|
||||
# First processor: rename raw keys to intermediate format
|
||||
processor1 = RenameProcessor(
|
||||
rename_map={
|
||||
"pos": "agent_position",
|
||||
"img": "camera_image",
|
||||
}
|
||||
)
|
||||
|
||||
# Second processor: rename to final format
|
||||
processor2 = RenameProcessor(
|
||||
rename_map={
|
||||
"agent_position": "observation.state",
|
||||
"camera_image": "observation.image",
|
||||
}
|
||||
)
|
||||
|
||||
pipeline = RobotProcessor([processor1, processor2])
|
||||
|
||||
observation = {
|
||||
"pos": np.array([1.0, 2.0]),
|
||||
"img": "image_data",
|
||||
"extra": "keep_me",
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Step through to see intermediate results
|
||||
results = list(pipeline.step_through(transition))
|
||||
|
||||
# After first processor
|
||||
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
|
||||
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
|
||||
|
||||
# After second processor
|
||||
final_obs = results[2][TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in final_obs
|
||||
assert "observation.image" in final_obs
|
||||
assert final_obs["extra"] == "keep_me"
|
||||
|
||||
# Original keys should be gone
|
||||
assert "pos" not in final_obs
|
||||
assert "img" not in final_obs
|
||||
assert "agent_position" not in final_obs
|
||||
assert "camera_image" not in final_obs
|
||||
|
||||
|
||||
def test_nested_observation_rename():
|
||||
"""Test renaming with nested observation structures."""
|
||||
rename_map = {
|
||||
"observation.images.left": "observation.camera.left_view",
|
||||
"observation.images.right": "observation.camera.right_view",
|
||||
"observation.proprio": "observation.proprioception",
|
||||
}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
observation = {
|
||||
"observation.images.left": torch.randn(3, 64, 64),
|
||||
"observation.images.right": torch.randn(3, 64, 64),
|
||||
"observation.proprio": torch.randn(7),
|
||||
"observation.gripper": torch.tensor([0.0]), # Not renamed
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renames
|
||||
assert "observation.camera.left_view" in processed_obs
|
||||
assert "observation.camera.right_view" in processed_obs
|
||||
assert "observation.proprioception" in processed_obs
|
||||
|
||||
# Check unchanged key
|
||||
assert "observation.gripper" in processed_obs
|
||||
|
||||
# Check old keys removed
|
||||
assert "observation.images.left" not in processed_obs
|
||||
assert "observation.images.right" not in processed_obs
|
||||
assert "observation.proprio" not in processed_obs
|
||||
|
||||
|
||||
def test_value_types_preserved():
|
||||
"""Test that various value types are preserved during renaming."""
|
||||
rename_map = {"old_tensor": "new_tensor", "old_array": "new_array", "old_scalar": "new_scalar"}
|
||||
processor = RenameProcessor(rename_map=rename_map)
|
||||
|
||||
tensor_value = torch.randn(3, 3)
|
||||
array_value = np.random.rand(2, 2)
|
||||
|
||||
observation = {
|
||||
"old_tensor": tensor_value,
|
||||
"old_array": array_value,
|
||||
"old_scalar": 42,
|
||||
"old_string": "hello",
|
||||
"old_dict": {"nested": "value"},
|
||||
"old_list": [1, 2, 3],
|
||||
}
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that values and types are preserved
|
||||
assert torch.equal(processed_obs["new_tensor"], tensor_value)
|
||||
assert np.array_equal(processed_obs["new_array"], array_value)
|
||||
assert processed_obs["new_scalar"] == 42
|
||||
assert processed_obs["old_string"] == "hello"
|
||||
assert processed_obs["old_dict"] == {"nested": "value"}
|
||||
assert processed_obs["old_list"] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_feature_contract_basic_renaming(policy_feature_factory):
|
||||
processor = RenameProcessor(rename_map={"a": "x", "b": "y"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
"b": policy_feature_factory(FeatureType.ACTION, (3,)),
|
||||
"c": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
|
||||
out = processor.feature_contract(features.copy())
|
||||
|
||||
# Values preserved and typed
|
||||
assert out["x"] == features["a"]
|
||||
assert out["y"] == features["b"]
|
||||
assert out["c"] == features["c"]
|
||||
|
||||
assert_contract_is_typed(out)
|
||||
# Input not mutated
|
||||
assert set(features) == {"a", "b", "c"}
|
||||
|
||||
|
||||
def test_feature_contract_overlapping_keys(policy_feature_factory):
|
||||
# Overlapping renames: both 'a' and 'b' exist. 'a'->'b', 'b'->'c'
|
||||
processor = RenameProcessor(rename_map={"a": "b", "b": "c"})
|
||||
features = {
|
||||
"a": policy_feature_factory(FeatureType.STATE, (1,)),
|
||||
"b": policy_feature_factory(FeatureType.STATE, (2,)),
|
||||
}
|
||||
out = processor.feature_contract(features)
|
||||
|
||||
assert set(out) == {"b", "c"}
|
||||
assert out["b"] == features["a"] # 'a' renamed to'b'
|
||||
assert out["c"] == features["b"] # 'b' renamed to 'c'
|
||||
assert_contract_is_typed(out)
|
||||
|
||||
|
||||
def test_feature_contract_chained_processors(policy_feature_factory):
|
||||
# Chain two rename processors at the contract level
|
||||
processor1 = RenameProcessor(rename_map={"pos": "agent_position", "img": "camera_image"})
|
||||
processor2 = RenameProcessor(
|
||||
rename_map={"agent_position": "observation.state", "camera_image": "observation.image"}
|
||||
)
|
||||
pipeline = RobotProcessor([processor1, processor2])
|
||||
|
||||
spec = {
|
||||
"pos": policy_feature_factory(FeatureType.STATE, (7,)),
|
||||
"img": policy_feature_factory(FeatureType.VISUAL, (3, 64, 64)),
|
||||
"extra": policy_feature_factory(FeatureType.ENV, (1,)),
|
||||
}
|
||||
out = pipeline.feature_contract(initial_features=spec)
|
||||
|
||||
assert set(out) == {"observation.state", "observation.image", "extra"}
|
||||
assert out["observation.state"] == spec["pos"]
|
||||
assert out["observation.image"] == spec["img"]
|
||||
assert out["extra"] == spec["extra"]
|
||||
assert_contract_is_typed(out)
|
||||
Reference in New Issue
Block a user