From 0cd84cc9f9eb8912386c9beed5a57812de6753e6 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Mon, 30 Jun 2025 16:22:57 +0200 Subject: [PATCH] fix(rebase) reverting files to main --- .gitattributes | 3 +- .gitignore | 8 +- .pre-commit-config.yaml | 10 +- Makefile | 38 ++++ README.md | 67 ++++-- benchmarks/video/capture_camera_feed.py | 2 +- examples/3_train_policy.py | 2 +- examples/4_train_policy_with_script.md | 8 +- lerobot/common/policies/act/modeling_act.py | 44 ++-- lerobot/common/policies/factory.py | 21 ++ .../policies/pi0fast/modeling_pi0fast.py | 27 ++- .../common/policies/tdmpc/modeling_tdmpc.py | 88 +++---- lerobot/common/robots/viperx/README.md | 182 +++++++++++++++ lerobot/common/utils/control_utils.py | 215 ++++++++++++++++++ lerobot/configs/parser.py | 1 - lerobot/configs/policies.py | 22 +- lerobot/find_port.py | 65 ++++++ pyproject.toml | 35 ++- .../policies/save_policy_to_safetensors.py | 2 +- tests/conftest.py | 19 +- tests/optim/test_optimizers.py | 192 +++++++++++++++- 21 files changed, 911 insertions(+), 140 deletions(-) diff --git a/.gitattributes b/.gitattributes index 44e16cf1d..7d89f37b2 100644 --- a/.gitattributes +++ b/.gitattributes @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - *.memmap filter=lfs diff=lfs merge=lfs -text *.stl filter=lfs diff=lfs merge=lfs -text *.safetensors filter=lfs diff=lfs merge=lfs -text *.mp4 filter=lfs diff=lfs merge=lfs -text *.arrow filter=lfs diff=lfs merge=lfs -text *.json !text !filter !merge !diff +tests/artifacts/cameras/*.png filter=lfs diff=lfs merge=lfs -text +*.bag filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index d6c51c90d..4ab886933 100644 --- a/.gitignore +++ b/.gitignore @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +# Dev scripts +.dev + # Logging logs tmp @@ -26,6 +29,7 @@ outputs # VS Code .vscode +.devcontainer # HPC nautilus/*.yaml @@ -91,10 +95,8 @@ coverage.xml .hypothesis/ .pytest_cache/ -# Ignore .cache except calibration +# Ignore .cache .cache/* -!.cache/calibration/ -!.cache/calibration/** # Translations *.mo diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a778ce0e9..e1f971d39 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,18 +37,18 @@ repos: - id: trailing-whitespace - repo: https://github.com/adhtruong/mirrors-typos - rev: v1.31.1 + rev: v1.33.1 hooks: - id: typos args: [--force-exclude] - repo: https://github.com/asottile/pyupgrade - rev: v3.19.1 + rev: v3.20.0 hooks: - id: pyupgrade - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.5 + rev: v0.11.13 hooks: - id: ruff args: [--fix] @@ -57,12 +57,12 @@ repos: ##### Security ##### - repo: https://github.com/gitleaks/gitleaks - rev: v8.24.3 + rev: v8.27.2 hooks: - id: gitleaks - repo: https://github.com/woodruffw/zizmor-pre-commit - rev: v1.5.2 + rev: v1.9.0 hooks: - id: zizmor diff --git a/Makefile b/Makefile index c82483cc3..9457dbe6e 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,8 @@ test-end-to-end: ${MAKE} DEVICE=$(DEVICE) test-diffusion-ete-eval ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-train ${MAKE} DEVICE=$(DEVICE) test-tdmpc-ete-eval + ${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-train + ${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval test-act-ete-train: python lerobot/scripts/train.py \ @@ -48,6 +50,7 @@ test-act-ete-train: --policy.n_action_steps=20 \ --policy.chunk_size=20 \ --policy.device=$(DEVICE) \ + --policy.push_to_hub=false \ --env.type=aloha \ --env.episode_length=5 \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ @@ -85,6 +88,7 @@ test-diffusion-ete-train: --policy.diffusion_step_embed_dim=32 \ --policy.num_inference_steps=10 \ --policy.device=$(DEVICE) \ + --policy.push_to_hub=false \ --env.type=pusht \ --env.episode_length=5 \ --dataset.repo_id=lerobot/pusht \ @@ -114,6 +118,7 @@ test-tdmpc-ete-train: python lerobot/scripts/train.py \ --policy.type=tdmpc \ --policy.device=$(DEVICE) \ + --policy.push_to_hub=false \ --env.type=xarm \ --env.task=XarmLift-v0 \ --env.episode_length=5 \ @@ -140,3 +145,36 @@ test-tdmpc-ete-eval: --env.task=XarmLift-v0 \ --eval.n_episodes=1 \ --eval.batch_size=1 + + +test-smolvla-ete-train: + python lerobot/scripts/train.py \ + --policy.type=smolvla \ + --policy.n_action_steps=20 \ + --policy.chunk_size=20 \ + --policy.device=$(DEVICE) \ + --policy.push_to_hub=false \ + --env.type=aloha \ + --env.episode_length=5 \ + --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ + --dataset.image_transforms.enable=true \ + --dataset.episodes="[0]" \ + --batch_size=2 \ + --steps=4 \ + --eval_freq=2 \ + --eval.n_episodes=1 \ + --eval.batch_size=1 \ + --save_freq=2 \ + --save_checkpoint=true \ + --log_freq=1 \ + --wandb.enable=false \ + --output_dir=tests/outputs/smolvla/ + +test-smolvla-ete-eval: + python lerobot/scripts/eval.py \ + --policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \ + --policy.device=$(DEVICE) \ + --env.type=aloha \ + --env.episode_length=5 \ + --eval.n_episodes=1 \ + --eval.batch_size=1 diff --git a/README.md b/README.md index 42e0ee2c5..09398021e 100644 --- a/README.md +++ b/README.md @@ -23,22 +23,36 @@

-

- Build Your Own SO-100 Robot!

+

+ Build Your Own SO-101 Robot!

- SO-100 leader and follower arms +
+ SO-101 follower arm + SO-101 leader arm +
-

Meet the SO-100 โ€“ Just $110 per arm!

+ +

Meet the updated SO100, the SO-101 โ€“ Just โ‚ฌ114 per arm!

Train it in minutes with a few simple moves on your laptop.

Then sit back and watch your creation act autonomously! ๐Ÿคฏ

-

- Get the full SO-100 tutorial here.

+

+ See the full SO-101 tutorial here.

-

Want to take it to the next level? Make your SO-100 mobile by building LeKiwi!

-

Check out the LeKiwi tutorial and bring your robot to life on wheels.

+

Want to take it to the next level? Make your SO-101 mobile by building LeKiwi!

+

Check out the LeKiwi tutorial and bring your robot to life on wheels.

LeKiwi mobile robot
@@ -51,7 +65,6 @@ --- - ๐Ÿค— LeRobot aims to provide models, datasets, and tools for real-world robotics in PyTorch. The goal is to lower the barrier to entry to robotics so that everyone can contribute and benefit from sharing datasets and pretrained models. ๐Ÿค— LeRobot contains state-of-the-art approaches that have been shown to transfer to the real-world with a focus on imitation learning and reinforcement learning. @@ -77,6 +90,7 @@ ### Acknowledgment +- The LeRobot team ๐Ÿค— for building SmolVLA [Paper](https://arxiv.org/abs/2506.01844), [Blog](https://huggingface.co/blog/smolvla). - Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). - Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io). - Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM). @@ -116,7 +130,7 @@ pip install -e . ``` > **NOTE:** If you encounter build errors, you may need to install additional dependencies (`cmake`, `build-essential`, and `ffmpeg libs`). On Linux, run: -`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) +`sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev`. For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg) For simulations, ๐Ÿค— LeRobot comes with gymnasium environments that can be installed as extras: - [aloha](https://github.com/huggingface/gym-aloha) @@ -198,7 +212,6 @@ Under the hood, the `LeRobotDataset` format makes use of several ways to seriali Here are the important details and internal structure organization of a typical `LeRobotDataset` instantiated with `dataset = LeRobotDataset("lerobot/aloha_static_coffee")`. The exact features will change from dataset to dataset but not the main aspects: ``` -TODO: IMPROVE dataset attributes: โ”œ hf_dataset: a Hugging Face dataset (backed by Arrow/parquet). Typical features example: โ”‚ โ”œ observation.images.cam_high (VideoFrame): @@ -209,9 +222,9 @@ dataset attributes: โ”‚ โ”œ episode_index (int64): index of the episode for this sample โ”‚ โ”œ frame_index (int64): index of the frame for this sample in the episode ; starts at 0 for each episode โ”‚ โ”œ timestamp (float32): timestamp in the episode - โ”‚ โ”œ next.done (bool): indicates the end of en episode ; True for the last frame in each episode + โ”‚ โ”œ next.done (bool): indicates the end of an episode ; True for the last frame in each episode โ”‚ โ”” index (int64): general index in the whole dataset - โ”œ meta: contains 2 tensors with the start and end indices of each episode + โ”œ episode_data_index: contains 2 tensors with the start and end indices of each episode โ”‚ โ”œ from (1D int64 tensor): first frame index for each episode โ€” shape (num episodes,) starts with 0 โ”‚ โ”” to: (1D int64 tensor): last frame index for each episode โ€” shape (num episodes,) โ”œ stats: a dictionary of statistics (max, mean, min, std) for each feature in the dataset, for instance @@ -258,7 +271,7 @@ See `python lerobot/scripts/eval.py --help` for more instructions. ### Train your own policy -Check out [example 3](./examples/3_train_policy.py) that illustrate how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line. +Check out [example 3](./examples/3_train_policy.py) that illustrates how to train a model using our core library in python, and [example 4](./examples/4_train_policy_with_script.md) that shows how to use our training script from command line. To use wandb for logging training and evaluation curves, make sure you've run `wandb login` as a one-time setup step. Then, when running the training command above, enable WandB in the configuration by adding `--wandb.enable=true`. @@ -309,7 +322,7 @@ Once you have trained a policy you may upload it to the Hugging Face hub using a You first need to find the checkpoint folder located inside your experiment directory (e.g. `outputs/train/2024-05-05/20-21-12_aloha_act_default/checkpoints/002500`). Within that there is a `pretrained_model` directory which should contain: - `config.json`: A serialized version of the policy configuration (following the policy's dataclass config). - `model.safetensors`: A set of `torch.nn.Module` parameters, saved in [Hugging Face Safetensors](https://huggingface.co/docs/safetensors/index) format. -- `train_config.json`: A consolidated configuration containing all parameter userd for training. The policy configuration should match `config.json` exactly. Thisis useful for anyone who wants to evaluate your policy or for reproducibility. +- `train_config.json`: A consolidated configuration containing all parameters used for training. The policy configuration should match `config.json` exactly. This is useful for anyone who wants to evaluate your policy or for reproducibility. To upload these to the hub, run the following: ```bash @@ -348,7 +361,7 @@ with profile( If you want, you can cite this work with: ```bibtex @misc{cadene2024lerobot, - author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Wolf, Thomas}, + author = {Cadene, Remi and Alibert, Simon and Soare, Alexander and Gallouedec, Quentin and Zouitine, Adil and Palma, Steven and Kooijmans, Pepijn and Aractingi, Michel and Shukor, Mustafa and Aubakirova, Dana and Russi, Martino and Capuano, Francesco and Pascale, Caroline and Choghari, Jade and Moss, Jess and Wolf, Thomas}, title = {LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch}, howpublished = "\url{https://github.com/huggingface/lerobot}", year = {2024} @@ -356,6 +369,15 @@ If you want, you can cite this work with: ``` Additionally, if you are using any of the particular policy architecture, pretrained models, or datasets, it is recommended to cite the original authors of the work as they appear below: +- [SmolVLA](https://arxiv.org/abs/2506.01844) +```bibtex +@article{shukor2025smolvla, + title={SmolVLA: A Vision-Language-Action Model for Affordable and Efficient Robotics}, + author={Shukor, Mustafa and Aubakirova, Dana and Capuano, Francesco and Kooijmans, Pepijn and Palma, Steven and Zouitine, Adil and Aractingi, Michel and Pascal, Caroline and Russi, Martino and Marafioti, Andres and Alibert, Simon and Cord, Matthieu and Wolf, Thomas and Cadene, Remi}, + journal={arXiv preprint arXiv:2506.01844}, + year={2025} +} +``` - [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) ```bibtex @@ -396,6 +418,19 @@ Additionally, if you are using any of the particular policy architecture, pretra year={2024} } ``` + + +- [HIL-SERL](https://hil-serl.github.io/) +```bibtex +@Article{luo2024hilserl, +title={Precise and Dexterous Robotic Manipulation via Human-in-the-Loop Reinforcement Learning}, +author={Jianlan Luo and Charles Xu and Jeffrey Wu and Sergey Levine}, +year={2024}, +eprint={2410.21845}, +archivePrefix={arXiv}, +primaryClass={cs.RO} +} +``` ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=huggingface/lerobot&type=Timeline)](https://star-history.com/#huggingface/lerobot&Timeline) diff --git a/benchmarks/video/capture_camera_feed.py b/benchmarks/video/capture_camera_feed.py index ce248f20b..8f8530532 100755 --- a/benchmarks/video/capture_camera_feed.py +++ b/benchmarks/video/capture_camera_feed.py @@ -55,7 +55,7 @@ def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height if not ret: print("Error: Could not read frame.") break - rr.log("video/stream", rr.Image(frame.numpy()), static=True) + rr.log("video/stream", rr.Image(frame), static=True) cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame) frame_index += 1 diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 6c3af54ea..f9c251a02 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""This scripts demonstrates how to train Diffusion Policy on the PushT environment. +"""This script demonstrates how to train Diffusion Policy on the PushT environment. Once you have trained a model with this script, you can try to evaluate it on examples/2_evaluate_pretrained_policy.py diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index 0c11afe98..cb4cc6268 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -1,5 +1,5 @@ This tutorial will explain the training script, how to use it, and particularly how to configure everything needed for the training run. -> **Note:** The following assume you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu. +> **Note:** The following assumes you're running these commands on a machine equipped with a cuda GPU. If you don't have one (or if you're using a Mac), you can add `--policy.device=cpu` (`--policy.device=mps` respectively). However, be advised that the code executes much slower on cpu. ## The training script @@ -23,7 +23,7 @@ def train(cfg: TrainPipelineConfig): You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option) -When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated for this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.) +When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated to this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.) Let's have a look at a simplified example. Amongst other attributes, the training config has the following attributes: ```python @@ -43,7 +43,7 @@ class DatasetConfig: ``` This creates a hierarchical relationship where, for example assuming we have a `cfg` instance of `TrainPipelineConfig`, we can access the `repo_id` value with `cfg.dataset.repo_id`. -From the command line, we can specify this value with using a very similar syntax `--dataset.repo_id=repo/id`. +From the command line, we can specify this value by using a very similar syntax `--dataset.repo_id=repo/id`. By default, every field takes its default value specified in the dataclass. If a field doesn't have a default value, it needs to be specified either from the command line or from a config file โ€“ which path is also given in the command line (more in this below). In the example above, the `dataset` field doesn't have a default value which means it must be specified. @@ -135,7 +135,7 @@ will start a training run with the same configuration used for training [lerobot ## Resume training -Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to that here. +Being able to resume a training run is important in case it crashed or aborted for any reason. We'll demonstrate how to do that here. Let's reuse the command from the previous run and add a few more options: ```bash diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 72d4df03a..122066577 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -15,7 +15,7 @@ # limitations under the License. """Action Chunking Transformer Policy -As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://arxiv.org/abs/2304.13705). +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://huggingface.co/papers/2304.13705). The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. """ @@ -33,6 +33,7 @@ from torch import Tensor, nn from torchvision.models._utils import IntermediateLayerGetter from torchvision.ops.misc import FrozenBatchNorm2d +from lerobot.common.constants import ACTION, OBS_IMAGES from lerobot.common.policies.act.configuration_act import ACTConfig from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -41,7 +42,7 @@ from lerobot.common.policies.pretrained import PreTrainedPolicy class ACTPolicy(PreTrainedPolicy): """ Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost - Hardware (paper: https://arxiv.org/abs/2304.13705, code: https://github.com/tonyzhaozh/act) + Hardware (paper: https://huggingface.co/papers/2304.13705, code: https://github.com/tonyzhaozh/act) """ config_class = ACTConfig @@ -114,46 +115,49 @@ class ACTPolicy(PreTrainedPolicy): environment. It works by managing the actions in a queue and only calling `select_actions` when the queue is empty. """ - self.eval() + self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed - batch = self.normalize_inputs(batch) - if self.config.image_features: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [batch[key] for key in self.config.image_features] - - # If we are doing temporal ensembling, do online updates where we keep track of the number of actions - # we are ensembling over. if self.config.temporal_ensemble_coeff is not None: - actions = self.model(batch)[0] # (batch_size, chunk_size, action_dim) - actions = self.unnormalize_outputs({"action": actions})["action"] + actions = self.predict_action_chunk(batch) action = self.temporal_ensembler.update(actions) return action # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by # querying the policy. if len(self._action_queue) == 0: - actions = self.model(batch)[0][:, : self.config.n_action_steps] - - # TODO(rcadene): make _forward return output dictionary? - actions = self.unnormalize_outputs({"action": actions})["action"] + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue # effectively has shape (n_action_steps, batch_size, *), hence the transpose. self._action_queue.extend(actions.transpose(0, 1)) return self._action_queue.popleft() + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] + + actions = self.model(batch)[0] + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = [batch[key] for key in self.config.image_features] + batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) ).mean() loss_dict = {"l1_loss": l1_loss.item()} @@ -161,7 +165,7 @@ class ACTPolicy(PreTrainedPolicy): # Calculate Dโ‚–โ‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total # KL-divergence per batch element, then take the mean over the batch. - # (See App. B of https://arxiv.org/abs/1312.6114 for more details). + # (See App. B of https://huggingface.co/papers/1312.6114 for more details). mean_kld = ( (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() ) @@ -175,7 +179,7 @@ class ACTPolicy(PreTrainedPolicy): class ACTTemporalEnsembler: def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: - """Temporal ensembling as described in Algorithm 2 of https://arxiv.org/abs/2304.13705. + """Temporal ensembling as described in Algorithm 2 of https://huggingface.co/papers/2304.13705. The weights are calculated as wแตข = exp(-temporal_ensemble_coeff * i) where wโ‚€ is the oldest action. They are then normalized to sum to 1 by dividing by ฮฃwแตข. Here's some intuition around how the diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py index 8def95a35..682bb8cee 100644 --- a/lerobot/common/policies/factory.py +++ b/lerobot/common/policies/factory.py @@ -27,6 +27,9 @@ from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionC from lerobot.common.policies.pi0.configuration_pi0 import PI0Config from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.common.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.common.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.configs.policies import PreTrainedConfig @@ -59,6 +62,18 @@ def get_policy_class(name: str) -> PreTrainedPolicy: from lerobot.common.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy return PI0FASTPolicy + elif name == "sac": + from lerobot.common.policies.sac.modeling_sac import SACPolicy + + return SACPolicy + elif name == "reward_classifier": + from lerobot.common.policies.sac.reward_model.modeling_classifier import Classifier + + return Classifier + elif name == "smolvla": + from lerobot.common.policies.smolvla.modeling_smolvla import SmolVLAPolicy + + return SmolVLAPolicy else: raise NotImplementedError(f"Policy with name {name} is not implemented.") @@ -76,6 +91,12 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: return PI0Config(**kwargs) elif policy_type == "pi0fast": return PI0FASTConfig(**kwargs) + elif policy_type == "sac": + return SACConfig(**kwargs) + elif policy_type == "smolvla": + return SmolVLAConfig(**kwargs) + elif policy_type == "reward_classifier": + return RewardClassifierConfig(**kwargs) else: raise ValueError(f"Policy type '{policy_type}' is not available.") diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py index 36aafce94..dbf5266b1 100644 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ b/lerobot/common/policies/pi0fast/modeling_pi0fast.py @@ -17,7 +17,7 @@ """ ฯ€0+FAST: Efficient Action Tokenization for Vision-Language-Action Models -[Paper](https://arxiv.org/abs/2501.09747) +[Paper](https://huggingface.co/papers/2501.09747) [Jax code](https://github.com/Physical-Intelligence/openpi) Designed by Physical Intelligence. Ported from Jax by Hugging Face. @@ -56,7 +56,7 @@ from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGe from transformers.cache_utils import HybridCache, StaticCache from transformers.models.auto import CONFIG_MAPPING -from lerobot.common.constants import ACTION, OBS_ROBOT +from lerobot.common.constants import ACTION, OBS_STATE from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig from lerobot.common.policies.pretrained import PreTrainedPolicy @@ -192,6 +192,11 @@ class PI0FASTPolicy(PreTrainedPolicy): actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) return actions + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + raise NotImplementedError("Currently not implemented for PI0FAST") + @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations. @@ -203,7 +208,7 @@ class PI0FASTPolicy(PreTrainedPolicy): self.eval() if self.config.adapt_to_pi_aloha: - batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch = self.normalize_inputs(batch) @@ -231,7 +236,7 @@ class PI0FASTPolicy(PreTrainedPolicy): def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: if self.config.adapt_to_pi_aloha: - batch[OBS_ROBOT] = self._pi_aloha_decode_state(batch[OBS_ROBOT]) + batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) batch = self.normalize_inputs(batch) batch = self.normalize_targets(batch) @@ -516,7 +521,7 @@ class PI0FAST(nn.Module): interpolate_like_pi=self.config.interpolate_like_pi, ) - # Normalize from range [0,1] to [-1,1] as expacted by siglip + # Normalize from range [0,1] to [-1,1] as expected by siglip img = img * 2.0 - 1.0 bsize = img.shape[0] @@ -677,12 +682,12 @@ class PI0FAST(nn.Module): return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids def forward(self, batch: dict[str, Tensor]): - device = batch[OBS_ROBOT].device + device = batch[OBS_STATE].device # TODO: keep like this or move to the policy .forward images, img_masks = self.prepare_images(batch) padded_outs = self.create_input_tokens( - state=batch[OBS_ROBOT], + state=batch[OBS_STATE], lang_text=batch["task"], actions=batch[ACTION], ) @@ -849,7 +854,7 @@ class PI0FAST(nn.Module): # TODO: keep like this or move to the policy .forward images, img_masks = self.prepare_images(batch) - padded_outs = self.create_input_tokens(state=batch[OBS_ROBOT], lang_text=batch["task"], actions=None) + padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None) embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( images, img_masks, @@ -878,7 +883,11 @@ class PI0FAST(nn.Module): return actions def embed_image(self, image: torch.Tensor): - return self.pi0_paligemma.get_image_features(image) + # Handle different transformers versions + if hasattr(self.pi0_paligemma, "get_image_features"): + return self.pi0_paligemma.get_image_features(image) + else: + return self.pi0_paligemma.model.get_image_features(image) def embed_inputs( self, diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index b46ae9030..4bb564f8f 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -17,8 +17,8 @@ """Implementation of Finetuning Offline World Models in the Real World. The comments in this code may sometimes refer to these references: - TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://arxiv.org/abs/2203.04955) - FOWM paper: Finetuning Offline World Models in the Real World (https://arxiv.org/abs/2310.16029) + TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://huggingface.co/papers/2203.04955) + FOWM paper: Finetuning Offline World Models in the Real World (https://huggingface.co/papers/2310.16029) """ # ruff: noqa: N806 @@ -35,7 +35,7 @@ import torch.nn as nn import torch.nn.functional as F # noqa: N812 from torch import Tensor -from lerobot.common.constants import OBS_ENV, OBS_ROBOT +from lerobot.common.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.pretrained import PreTrainedPolicy from lerobot.common.policies.tdmpc.configuration_tdmpc import TDMPCConfig @@ -110,52 +110,58 @@ class TDMPCPolicy(PreTrainedPolicy): # CEM for the next step. self._prev_mean: torch.Tensor | None = None + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues} + + # Remove the time dimensions as it is not handled yet. + for key in batch: + assert batch[key].shape[1] == 1 + batch[key] = batch[key][:, 0] + + # NOTE: Order of observations matters here. + encode_keys = [] + if self.config.image_features: + encode_keys.append(OBS_IMAGE) + if self.config.env_state_feature: + encode_keys.append(OBS_ENV_STATE) + encode_keys.append(OBS_STATE) + z = self.model.encode({k: batch[k] for k in encode_keys}) + if self.config.use_mpc: # noqa: SIM108 + actions = self.plan(z) # (horizon, batch, action_dim) + else: + # Plan with the policy (ฯ€) alone. This always returns one action so unsqueeze to get a + # sequence dimension like in the MPC branch. + actions = self.model.pi(z).unsqueeze(0) + + actions = torch.clamp(actions, -1, +1) + + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[next(iter(self.config.image_features))] + batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] self._queues = populate_queues(self._queues, batch) # When the action queue is depleted, populate it again by querying the policy. - if len(self._queues["action"]) == 0: - batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues} - - # Remove the time dimensions as it is not handled yet. - for key in batch: - assert batch[key].shape[1] == 1 - batch[key] = batch[key][:, 0] - - # NOTE: Order of observations matters here. - encode_keys = [] - if self.config.image_features: - encode_keys.append("observation.image") - if self.config.env_state_feature: - encode_keys.append("observation.environment_state") - encode_keys.append("observation.state") - z = self.model.encode({k: batch[k] for k in encode_keys}) - if self.config.use_mpc: # noqa: SIM108 - actions = self.plan(z) # (horizon, batch, action_dim) - else: - # Plan with the policy (ฯ€) alone. This always returns one action so unsqueeze to get a - # sequence dimension like in the MPC branch. - actions = self.model.pi(z).unsqueeze(0) - - actions = torch.clamp(actions, -1, +1) - - actions = self.unnormalize_outputs({"action": actions})["action"] + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) if self.config.n_action_repeats > 1: for _ in range(self.config.n_action_repeats): - self._queues["action"].append(actions[0]) + self._queues[ACTION].append(actions[0]) else: # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action. - self._queues["action"].extend(actions[: self.config.n_action_steps]) + self._queues[ACTION].extend(actions[: self.config.n_action_steps]) - action = self._queues["action"].popleft() + action = self._queues[ACTION].popleft() return action @torch.no_grad() @@ -312,7 +318,7 @@ class TDMPCPolicy(PreTrainedPolicy): batch = self.normalize_inputs(batch) if self.config.image_features: batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.image"] = batch[next(iter(self.config.image_features))] + batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] batch = self.normalize_targets(batch) info = {} @@ -322,15 +328,15 @@ class TDMPCPolicy(PreTrainedPolicy): if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: batch[key] = batch[key].transpose(1, 0) - action = batch["action"] # (t, b, action_dim) - reward = batch["next.reward"] # (t, b) + action = batch[ACTION] # (t, b, action_dim) + reward = batch[REWARD] # (t, b) observations = {k: v for k, v in batch.items() if k.startswith("observation.")} # Apply random image augmentations. if self.config.image_features and self.config.max_random_shift_ratio > 0: - observations["observation.image"] = flatten_forward_unflatten( + observations[OBS_IMAGE] = flatten_forward_unflatten( partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), - observations["observation.image"], + observations[OBS_IMAGE], ) # Get the current observation for predicting trajectories, and all future observations for use in @@ -340,7 +346,7 @@ class TDMPCPolicy(PreTrainedPolicy): current_observation[k] = observations[k][0] next_observations[k] = observations[k][1:] horizon, batch_size = next_observations[ - "observation.image" if self.config.image_features else "observation.environment_state" + OBS_IMAGE if self.config.image_features else OBS_ENV_STATE ].shape[:2] # Run latent rollout using the latent dynamics model and policy model. @@ -753,9 +759,9 @@ class TDMPCObservationEncoder(nn.Module): ) ) if self.config.env_state_feature: - feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV])) + feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE])) if self.config.robot_state_feature: - feat.append(self.state_enc_layers(obs_dict[OBS_ROBOT])) + feat.append(self.state_enc_layers(obs_dict[OBS_STATE])) return torch.stack(feat, dim=0).mean(0) diff --git a/lerobot/common/robots/viperx/README.md b/lerobot/common/robots/viperx/README.md index e69de29bb..be2a323b6 100644 --- a/lerobot/common/robots/viperx/README.md +++ b/lerobot/common/robots/viperx/README.md @@ -0,0 +1,182 @@ +This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot. + +## Setup + +Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer. + + +## Install LeRobot + +On your computer: + +1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install): +```bash +mkdir -p ~/miniconda3 +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh +bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 +rm ~/miniconda3/miniconda.sh +~/miniconda3/bin/conda init bash +``` + +2. Restart shell or `source ~/.bashrc` + +3. Create and activate a fresh conda environment for lerobot +```bash +conda create -y -n lerobot python=3.10 && conda activate lerobot +``` + +4. Clone LeRobot: +```bash +git clone https://github.com/huggingface/lerobot.git ~/lerobot +``` + +5. When using `miniconda`, install `ffmpeg` in your environment: +```bash +conda install ffmpeg -c conda-forge +``` + +6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense): +```bash +cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]" +``` + +## Teleoperate + +**/!\ FOR SAFETY, READ THIS /!\** +Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly: +1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms, +2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics. + +By running the following code, you can start your first **SAFE** teleoperation: + +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. + +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=5 \ + --control.type=teleoperate +``` + +By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/common/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=null \ + --control.type=teleoperate +``` + +## Record a dataset + +Once you're familiar with teleoperation, you can record your first dataset with Aloha. + +If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): +```bash +huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +``` + +Store your Hugging Face repository name in a variable to run these commands: +```bash +HF_USER=$(huggingface-cli whoami | head -n 1) +echo $HF_USER +``` + +Record 2 episodes and upload your dataset to the hub: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=null \ + --control.type=record \ + --control.fps=30 \ + --control.single_task="Grasp a lego block and put it in the bin." \ + --control.repo_id=${HF_USER}/aloha_test \ + --control.tags='["tutorial"]' \ + --control.warmup_time_s=5 \ + --control.episode_time_s=30 \ + --control.reset_time_s=30 \ + --control.num_episodes=2 \ + --control.push_to_hub=true +``` + +## Visualize a dataset + +If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: +```bash +echo ${HF_USER}/aloha_test +``` + +If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with: +```bash +python lerobot/scripts/visualize_dataset_html.py \ + --repo-id ${HF_USER}/aloha_test +``` + +## Replay an episode + +**/!\ FOR SAFETY, READ THIS /!\** +Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above. + +Now try to replay the first episode on your robot: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=null \ + --control.type=replay \ + --control.fps=30 \ + --control.repo_id=${HF_USER}/aloha_test \ + --control.episode=0 +``` + +## Train a policy + +To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +```bash +python lerobot/scripts/train.py \ + --dataset.repo_id=${HF_USER}/aloha_test \ + --policy.type=act \ + --output_dir=outputs/train/act_aloha_test \ + --job_name=act_aloha_test \ + --policy.device=cuda \ + --wandb.enable=true +``` + +Let's explain it: +1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../lerobot/common/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. + +For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md) + +Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`. + +## Evaluate your policy + +You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --control.type=record \ + --control.fps=30 \ + --control.single_task="Grasp a lego block and put it in the bin." \ + --control.repo_id=${HF_USER}/eval_act_aloha_test \ + --control.tags='["tutorial"]' \ + --control.warmup_time_s=5 \ + --control.episode_time_s=30 \ + --control.reset_time_s=30 \ + --control.num_episodes=10 \ + --control.push_to_hub=true \ + --control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \ + --control.num_image_writer_processes=1 +``` + +As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: +1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`). +2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`). +3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`. + +## More + +Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation. + +If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`. diff --git a/lerobot/common/utils/control_utils.py b/lerobot/common/utils/control_utils.py index e69de29bb..b66977a72 100644 --- a/lerobot/common/utils/control_utils.py +++ b/lerobot/common/utils/control_utils.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +######################################################################################## +# Utilities +######################################################################################## + + +import logging +import traceback +from contextlib import nullcontext +from copy import copy +from functools import cache + +import numpy as np +import torch +from deepdiff import DeepDiff +from termcolor import colored + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.datasets.utils import DEFAULT_FEATURES +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.robots import Robot + + +def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + log_items = [] + if episode_index is not None: + log_items.append(f"ep:{episode_index}") + if frame_index is not None: + log_items.append(f"frame:{frame_index}") + + def log_dt(shortname, dt_val_s): + nonlocal log_items, fps + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" + if fps is not None: + actual_fps = 1 / dt_val_s + if actual_fps < fps - 1: + info_str = colored(info_str, "yellow") + log_items.append(info_str) + + # total step time displayed in milliseconds and its frequency + log_dt("dt", dt_s) + + # TODO(aliberts): move robot-specific logs logic in robot.print_logs() + if not robot.robot_type.startswith("stretch"): + for name in robot.leader_arms: + key = f"read_leader_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRlead", robot.logs[key]) + + for name in robot.follower_arms: + key = f"write_follower_{name}_goal_pos_dt_s" + if key in robot.logs: + log_dt("dtWfoll", robot.logs[key]) + + key = f"read_follower_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRfoll", robot.logs[key]) + + for name in robot.cameras: + key = f"read_camera_{name}_dt_s" + if key in robot.logs: + log_dt(f"dtR{name}", robot.logs[key]) + + info_str = " ".join(log_items) + logging.info(info_str) + + +@cache +def is_headless(): + """Detects if python is running without a monitor.""" + try: + import pynput # noqa + + return False + except Exception: + print( + "Error trying to import pynput. Switching to headless mode. " + "As a result, the video stream from the cameras won't be shown, " + "and you won't be able to change the control flow with keyboards. " + "For more info, see traceback below.\n" + ) + traceback.print_exc() + print() + return True + + +def predict_action( + observation: dict[str, np.ndarray], + policy: PreTrainedPolicy, + device: torch.device, + use_amp: bool, + task: str | None = None, + robot_type: str | None = None, +): + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + observation[name] = torch.from_numpy(observation[name]) + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + observation["task"] = task if task else "" + observation["robot_type"] = robot_type if robot_type else "" + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def sanity_check_dataset_name(repo_id, policy_cfg): + _, dataset_name = repo_id.split("/") + # either repo_id doesnt start with "eval_" and there is no policy + # or repo_id starts with "eval_" and there is a policy + + # Check if dataset_name starts with "eval_" but policy is missing + if dataset_name.startswith("eval_") and policy_cfg is None: + raise ValueError( + f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." + ) + + # Check if dataset_name does not start with "eval_" but policy is provided + if not dataset_name.startswith("eval_") and policy_cfg is not None: + raise ValueError( + f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})." + ) + + +def sanity_check_dataset_robot_compatibility( + dataset: LeRobotDataset, robot: Robot, fps: int, features: dict +) -> None: + fields = [ + ("robot_type", dataset.meta.robot_type, robot.robot_type), + ("fps", dataset.fps, fps), + ("features", dataset.features, {**features, **DEFAULT_FEATURES}), + ] + + mismatches = [] + for field, dataset_value, present_value in fields: + diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) + if diff: + mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") + + if mismatches: + raise ValueError( + "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) + ) diff --git a/lerobot/configs/parser.py b/lerobot/configs/parser.py index 39e315152..f69b5a7fa 100644 --- a/lerobot/configs/parser.py +++ b/lerobot/configs/parser.py @@ -26,7 +26,6 @@ from lerobot.common.utils.utils import has_method PATH_KEY = "path" PLUGIN_DISCOVERY_SUFFIX = "discover_packages_path" -draccus.set_config_type("json") def get_cli_overrides(field_name: str, args: Sequence[str] | None = None) -> list[str] | None: diff --git a/lerobot/configs/policies.py b/lerobot/configs/policies.py index 022d1fb52..9e7f3dd56 100644 --- a/lerobot/configs/policies.py +++ b/lerobot/configs/policies.py @@ -60,6 +60,16 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # automatic gradient scaling is used. use_amp: bool = False + push_to_hub: bool = True + repo_id: str | None = None + + # Upload on private repository on the Hugging Face hub. + private: bool | None = None + # Add tags to your policy on the hub. + tags: list[str] | None = None + # Add tags to your policy on the hub. + license: str | None = None + def __post_init__(self): self.pretrained_path = None if not self.device or not is_torch_device_available(self.device): @@ -78,15 +88,18 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): def type(self) -> str: return self.get_choice_name(self.__class__) - @abc.abstractproperty + @property + @abc.abstractmethod def observation_delta_indices(self) -> list | None: raise NotImplementedError - @abc.abstractproperty + @property + @abc.abstractmethod def action_delta_indices(self) -> list | None: raise NotImplementedError - @abc.abstractproperty + @property + @abc.abstractmethod def reward_delta_indices(self) -> list | None: raise NotImplementedError @@ -173,4 +186,5 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus # something like --policy.path (in addition to --policy.type) cli_overrides = policy_kwargs.pop("cli_overrides", []) - return draccus.parse(cls, config_file, args=cli_overrides) + with draccus.config_type("json"): + return draccus.parse(cls, config_file, args=cli_overrides) diff --git a/lerobot/find_port.py b/lerobot/find_port.py index e69de29bb..cf0282507 100644 --- a/lerobot/find_port.py +++ b/lerobot/find_port.py @@ -0,0 +1,65 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper to find the USB port associated with your MotorsBus. + +Example: + +```shell +python -m lerobot.find_port +``` +""" + +import platform +import time +from pathlib import Path + + +def find_available_ports(): + from serial.tools import list_ports # Part of pyserial library + + if platform.system() == "Windows": + # List COM ports using pyserial + ports = [port.device for port in list_ports.comports()] + else: # Linux/macOS + # List /dev/tty* ports for Unix-based systems + ports = [str(path) for path in Path("/dev").glob("tty*")] + return ports + + +def find_port(): + print("Finding all available ports for the MotorsBus.") + ports_before = find_available_ports() + print("Ports before disconnecting:", ports_before) + + print("Remove the USB cable from your MotorsBus and press Enter when done.") + input() # Wait for user to disconnect the device + + time.sleep(0.5) # Allow some time for port to be released + ports_after = find_available_ports() + ports_diff = list(set(ports_before) - set(ports_after)) + + if len(ports_diff) == 1: + port = ports_diff[0] + print(f"The port of this MotorsBus is '{port}'") + print("Reconnect the USB cable.") + elif len(ports_diff) == 0: + raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).") + else: + raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") + + +if __name__ == "__main__": + find_port() diff --git a/pyproject.toml b/pyproject.toml index 4b858634d..5bff0fca6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,11 +49,11 @@ dependencies = [ "datasets>=2.19.0", "deepdiff>=7.0.1", "diffusers>=0.27.2", - "draccus>=0.10.0", + "draccus==0.10.0", "einops>=0.8.0", "flask>=3.0.3", "gdown>=5.1.0", - "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work + "gymnasium==0.29.1", # TODO(rcadene, aliberts): Make gym 1.0.0 work "h5py>=3.10.0", "huggingface-hub[hf-transfer,cli]>=0.27.1 ; python_version < '4.0'", "imageio[ffmpeg]>=2.34.0", @@ -62,11 +62,13 @@ dependencies = [ "omegaconf>=2.3.0", "opencv-python-headless>=4.9.0", "packaging>=24.2", - "av>=12.0.5", - "pymunk>=6.6.0", + "av>=14.2.0", + "pymunk>=6.6.0,<7.0.0", "pynput>=1.7.7", + "pyserial>=3.5", "pyzmq>=26.2.1", "rerun-sdk>=0.21.0", + "scipy>=1.14.0", "termcolor>=2.4.0", "torch>=2.2.1", "torchcodec>=0.2.1; sys_platform != 'win32' and (sys_platform != 'linux' or (platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l')) and (sys_platform != 'darwin' or platform_machine != 'x86_64')", @@ -77,22 +79,28 @@ dependencies = [ [project.optional-dependencies] aloha = ["gym-aloha>=0.1.1 ; python_version < '4.0'"] +docs = ["hf-doc-builder @ git+https://github.com/huggingface/doc-builder.git@main", "watchdog >= 6.0.0"] dev = ["pre-commit>=3.7.0", "debugpy>=1.8.1"] dora = [ "gym-dora @ git+https://github.com/dora-rs/dora-lerobot.git#subdirectory=gym_dora ; python_version < '4.0'", ] -dynamixel = ["dynamixel-sdk>=3.7.31", "pynput>=1.7.7"] -feetech = ["feetech-servo-sdk>=1.0.0", "pynput>=1.7.7"] -intelrealsense = ["pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'"] -pi0 = ["transformers>=4.48.0"] +dynamixel = ["dynamixel-sdk>=3.7.31"] +feetech = ["feetech-servo-sdk>=1.0.0"] +gamepad = ["pygame>=2.5.1", "hidapi>=0.14.0"] +intelrealsense = [ + "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", + "pyrealsense2-macosx>=2.54 ; sys_platform == 'darwin'", +] +pi0 = ["transformers>=4.50.3"] +smolvla = ["transformers>=4.50.3", "num2words>=0.5.14", "accelerate>=1.7.0", "safetensors>=0.4.3"] pusht = ["gym-pusht>=0.1.5 ; python_version < '4.0'"] stretch = [ "hello-robot-stretch-body>=0.7.27 ; python_version < '4.0' and sys_platform == 'linux'", "pyrender @ git+https://github.com/mmatl/pyrender.git ; sys_platform == 'linux'", - "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'", - "pynput>=1.7.7", + "pyrealsense2>=2.55.1.6486 ; sys_platform != 'darwin'" ] -test = ["pytest>=8.1.0", "pytest-cov>=5.0.0", "pyserial>=3.5"] +test = ["pytest>=8.1.0", "pytest-timeout>=2.4.0", "pytest-cov>=5.0.0", "pyserial>=3.5", "mock-serial>=0.0.1 ; sys_platform != 'win32'"] +hilserl = ["transformers>=4.50.3", "gym-hil>=0.1.8", "protobuf>=5.29.3", "grpcio==1.71.0"] umi = ["imagecodecs>=2024.1.1"] video_benchmark = ["scikit-image>=0.23.2", "pandas>=2.2.2"] xarm = ["gym-xarm>=0.1.1 ; python_version < '4.0'"] @@ -103,11 +111,14 @@ requires-poetry = ">=2.1" [tool.ruff] line-length = 110 target-version = "py310" -exclude = ["tests/artifacts/**/*.safetensors"] +exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"] [tool.ruff.lint] select = ["E4", "E7", "E9", "F", "I", "N", "B", "C4", "SIM"] +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401", "F403"] + [tool.bandit] exclude_dirs = [ "tests", diff --git a/tests/artifacts/policies/save_policy_to_safetensors.py b/tests/artifacts/policies/save_policy_to_safetensors.py index 106f0dc04..785f296c7 100644 --- a/tests/artifacts/policies/save_policy_to_safetensors.py +++ b/tests/artifacts/policies/save_policy_to_safetensors.py @@ -32,7 +32,7 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict): train_cfg = TrainPipelineConfig( # TODO(rcadene, aliberts): remove dataset download dataset=DatasetConfig(repo_id=ds_repo_id, episodes=[0]), - policy=make_policy_config(policy_name, **policy_kwargs), + policy=make_policy_config(policy_name, push_to_hub=False, **policy_kwargs), ) train_cfg.validate() # Needed for auto-setting some parameters diff --git a/tests/conftest.py b/tests/conftest.py index 7eec94bf8..69dd3049b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,9 +19,7 @@ import traceback import pytest from serial import SerialException -from lerobot import available_cameras, available_motors, available_robots -from lerobot.common.robot_devices.robots.utils import make_robot -from tests.utils import DEVICE, make_camera, make_motors_bus +from tests.utils import DEVICE # Import fixture modules as plugins pytest_plugins = [ @@ -64,21 +62,6 @@ def _check_component_availability(component_type, available_components, make_com return False -@pytest.fixture -def is_robot_available(robot_type): - return _check_component_availability(robot_type, available_robots, make_robot) - - -@pytest.fixture -def is_camera_available(camera_type): - return _check_component_availability(camera_type, available_cameras, make_camera) - - -@pytest.fixture -def is_motor_available(motor_type): - return _check_component_availability(motor_type, available_motors, make_motors_bus) - - @pytest.fixture def patch_builtins_input(monkeypatch): def print_text(text=None): diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py index 997e14fe9..630353fca 100644 --- a/tests/optim/test_optimizers.py +++ b/tests/optim/test_optimizers.py @@ -21,6 +21,7 @@ from lerobot.common.constants import ( from lerobot.common.optim.optimizers import ( AdamConfig, AdamWConfig, + MultiAdamConfig, SGDConfig, load_optimizer_state, save_optimizer_state, @@ -33,13 +34,21 @@ from lerobot.common.optim.optimizers import ( (AdamConfig, torch.optim.Adam), (AdamWConfig, torch.optim.AdamW), (SGDConfig, torch.optim.SGD), + (MultiAdamConfig, dict), ], ) def test_optimizer_build(config_cls, expected_class, model_params): config = config_cls() - optimizer = config.build(model_params) - assert isinstance(optimizer, expected_class) - assert optimizer.defaults["lr"] == config.lr + if config_cls == MultiAdamConfig: + params_dict = {"default": model_params} + optimizer = config.build(params_dict) + assert isinstance(optimizer, expected_class) + assert isinstance(optimizer["default"], torch.optim.Adam) + assert optimizer["default"].defaults["lr"] == config.lr + else: + optimizer = config.build(model_params) + assert isinstance(optimizer, expected_class) + assert optimizer.defaults["lr"] == config.lr def test_save_optimizer_state(optimizer, tmp_path): @@ -54,3 +63,180 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path): loaded_optimizer = load_optimizer_state(loaded_optimizer, tmp_path) torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict()) + + +@pytest.fixture +def base_params_dict(): + return { + "actor": [torch.nn.Parameter(torch.randn(10, 10))], + "critic": [torch.nn.Parameter(torch.randn(5, 5))], + "temperature": [torch.nn.Parameter(torch.randn(3, 3))], + } + + +@pytest.mark.parametrize( + "config_params, expected_values", + [ + # Test 1: Basic configuration with different learning rates + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999)}, + }, + ), + # Test 2: Different weight decays and beta values + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4, "weight_decay": 1e-5}, + "critic": {"lr": 5e-4, "weight_decay": 1e-6}, + "temperature": {"lr": 2e-3, "betas": (0.95, 0.999)}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-5, "betas": (0.9, 0.999)}, + "critic": {"lr": 5e-4, "weight_decay": 1e-6, "betas": (0.9, 0.999)}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.95, 0.999)}, + }, + ), + # Test 3: Epsilon parameter customization + ( + { + "lr": 1e-3, + "weight_decay": 1e-4, + "optimizer_groups": { + "actor": {"lr": 1e-4, "eps": 1e-6}, + "critic": {"lr": 5e-4, "eps": 1e-7}, + "temperature": {"lr": 2e-3, "eps": 1e-8}, + }, + }, + { + "actor": {"lr": 1e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-6}, + "critic": {"lr": 5e-4, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-7}, + "temperature": {"lr": 2e-3, "weight_decay": 1e-4, "betas": (0.9, 0.999), "eps": 1e-8}, + }, + ), + ], +) +def test_multi_adam_configuration(base_params_dict, config_params, expected_values): + # Create config with the given parameters + config = MultiAdamConfig(**config_params) + optimizers = config.build(base_params_dict) + + # Verify optimizer count and keys + assert len(optimizers) == len(expected_values) + assert set(optimizers.keys()) == set(expected_values.keys()) + + # Check that all optimizers are Adam instances + for opt in optimizers.values(): + assert isinstance(opt, torch.optim.Adam) + + # Verify hyperparameters for each optimizer + for name, expected in expected_values.items(): + optimizer = optimizers[name] + for param, value in expected.items(): + assert optimizer.defaults[param] == value + + +@pytest.fixture +def multi_optimizers(base_params_dict): + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + return config.build(base_params_dict) + + +def test_save_multi_optimizer_state(multi_optimizers, tmp_path): + # Save optimizer states + save_optimizer_state(multi_optimizers, tmp_path) + + # Verify that directories were created for each optimizer + for name in multi_optimizers: + assert (tmp_path / name).is_dir() + assert (tmp_path / name / OPTIMIZER_STATE).is_file() + assert (tmp_path / name / OPTIMIZER_PARAM_GROUPS).is_file() + + +def test_save_and_load_multi_optimizer_state(base_params_dict, multi_optimizers, tmp_path): + # Option 1: Add a minimal backward pass to populate optimizer states + for name, params in base_params_dict.items(): + if name in multi_optimizers: + # Create a dummy loss and do backward + dummy_loss = params[0].sum() + dummy_loss.backward() + # Perform an optimization step + multi_optimizers[name].step() + # Zero gradients for next steps + multi_optimizers[name].zero_grad() + + # Save optimizer states + save_optimizer_state(multi_optimizers, tmp_path) + + # Create new optimizers with the same config + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + new_optimizers = config.build(base_params_dict) + + # Load optimizer states + loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) + + # Verify state dictionaries match + for name in multi_optimizers: + torch.testing.assert_close(multi_optimizers[name].state_dict(), loaded_optimizers[name].state_dict()) + + +def test_save_and_load_empty_multi_optimizer_state(base_params_dict, tmp_path): + """Test saving and loading optimizer states even when the state is empty (no backward pass).""" + # Create config and build optimizers + config = MultiAdamConfig( + lr=1e-3, + optimizer_groups={ + "actor": {"lr": 1e-4}, + "critic": {"lr": 5e-4}, + "temperature": {"lr": 2e-3}, + }, + ) + optimizers = config.build(base_params_dict) + + # Save optimizer states without any backward pass (empty state) + save_optimizer_state(optimizers, tmp_path) + + # Create new optimizers with the same config + new_optimizers = config.build(base_params_dict) + + # Load optimizer states + loaded_optimizers = load_optimizer_state(new_optimizers, tmp_path) + + # Verify hyperparameters match even with empty state + for name, optimizer in optimizers.items(): + assert optimizer.defaults["lr"] == loaded_optimizers[name].defaults["lr"] + assert optimizer.defaults["weight_decay"] == loaded_optimizers[name].defaults["weight_decay"] + assert optimizer.defaults["betas"] == loaded_optimizers[name].defaults["betas"] + + # Verify state dictionaries match (they will be empty) + torch.testing.assert_close( + optimizer.state_dict()["param_groups"], loaded_optimizers[name].state_dict()["param_groups"] + )