From 1c17419224fca17930d998c20c54b3a6cbad8706 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 2 Jul 2025 17:26:34 +0200 Subject: [PATCH] Reverted back files that were changed during the rebase --- Makefile | 18 +- examples/4_train_policy_with_script.md | 42 +- examples/lekiwi/record.py | 3 +- .../policies/pi0fast/configuration_pi0fast.py | 136 --- .../policies/pi0fast/modeling_pi0fast.py | 982 ------------------ .../v21/convert_dataset_v20_to_v21.py | 2 +- src/lerobot/find_port.py | 65 ++ src/lerobot/policies/act/modeling_act.py | 769 ++++++++++++++ src/lerobot/policies/factory.py | 178 ++++ src/lerobot/policies/tdmpc/modeling_tdmpc.py | 834 +++++++++++++++ src/lerobot/robots/viperx/README.md | 182 ++++ src/lerobot/scripts/eval.py | 506 +++++++++ src/lerobot/utils/control_utils.py | 215 ++++ tests/configs/test_plugin_loading.py | 4 +- tests/datasets/test_image_transforms.py | 4 +- tests/envs/test_envs.py | 4 +- tests/fixtures/constants.py | 2 +- tests/optim/test_optimizers.py | 4 +- tests/policies/test_policies.py | 26 +- 19 files changed, 2803 insertions(+), 1173 deletions(-) delete mode 100644 lerobot/common/policies/pi0fast/configuration_pi0fast.py delete mode 100644 lerobot/common/policies/pi0fast/modeling_pi0fast.py diff --git a/Makefile b/Makefile index 9457dbe6e..ca1495fac 100644 --- a/Makefile +++ b/Makefile @@ -44,7 +44,7 @@ test-end-to-end: ${MAKE} DEVICE=$(DEVICE) test-smolvla-ete-eval test-act-ete-train: - python lerobot/scripts/train.py \ + python -m lerobot.scripts.train \ --policy.type=act \ --policy.dim_model=64 \ --policy.n_action_steps=20 \ @@ -68,12 +68,12 @@ test-act-ete-train: --output_dir=tests/outputs/act/ test-act-ete-train-resume: - python lerobot/scripts/train.py \ + python -m lerobot.scripts.train \ --config_path=tests/outputs/act/checkpoints/000002/pretrained_model/train_config.json \ --resume=true test-act-ete-eval: - python lerobot/scripts/eval.py \ + python -m lerobot.scripts.eval \ --policy.path=tests/outputs/act/checkpoints/000004/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=aloha \ @@ -82,7 +82,7 @@ test-act-ete-eval: --eval.batch_size=1 test-diffusion-ete-train: - python lerobot/scripts/train.py \ + python -m lerobot.scripts.train \ --policy.type=diffusion \ --policy.down_dims='[64,128,256]' \ --policy.diffusion_step_embed_dim=32 \ @@ -106,7 +106,7 @@ test-diffusion-ete-train: --output_dir=tests/outputs/diffusion/ test-diffusion-ete-eval: - python lerobot/scripts/eval.py \ + python -m lerobot.scripts.eval \ --policy.path=tests/outputs/diffusion/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=pusht \ @@ -115,7 +115,7 @@ test-diffusion-ete-eval: --eval.batch_size=1 test-tdmpc-ete-train: - python lerobot/scripts/train.py \ + python -m lerobot.scripts.train \ --policy.type=tdmpc \ --policy.device=$(DEVICE) \ --policy.push_to_hub=false \ @@ -137,7 +137,7 @@ test-tdmpc-ete-train: --output_dir=tests/outputs/tdmpc/ test-tdmpc-ete-eval: - python lerobot/scripts/eval.py \ + python -m lerobot.scripts.eval \ --policy.path=tests/outputs/tdmpc/checkpoints/000002/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=xarm \ @@ -148,7 +148,7 @@ test-tdmpc-ete-eval: test-smolvla-ete-train: - python lerobot/scripts/train.py \ + python -m lerobot.scripts.train \ --policy.type=smolvla \ --policy.n_action_steps=20 \ --policy.chunk_size=20 \ @@ -171,7 +171,7 @@ test-smolvla-ete-train: --output_dir=tests/outputs/smolvla/ test-smolvla-ete-eval: - python lerobot/scripts/eval.py \ + python -m lerobot.scripts.eval \ --policy.path=tests/outputs/smolvla/checkpoints/000004/pretrained_model \ --policy.device=$(DEVICE) \ --env.type=aloha \ diff --git a/examples/4_train_policy_with_script.md b/examples/4_train_policy_with_script.md index ff8913016..f17411b75 100644 --- a/examples/4_train_policy_with_script.md +++ b/examples/4_train_policy_with_script.md @@ -4,7 +4,7 @@ This tutorial will explain the training script, how to use it, and particularly ## The training script -LeRobot offers a training script at [`lerobot/scripts/train.py`](../../lerobot/scripts/train.py). At a high level it does the following: +LeRobot offers a training script at [`lerobot/scripts/train.py`](../src/lerobot/scripts/train.py). At a high level it does the following: - Initialize/load a configuration for the following steps using. - Instantiates a dataset. @@ -21,7 +21,7 @@ In the training script, the main function `train` expects a `TrainPipelineConfig def train(cfg: TrainPipelineConfig): ``` -You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../../lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option) +You can inspect the `TrainPipelineConfig` defined in [`lerobot/configs/train.py`](../src/lerobot/configs/train.py) (which is heavily commented and meant to be a reference to understand any option) When running the script, inputs for the command line are parsed thanks to the `@parser.wrap()` decorator and an instance of this class is automatically generated. Under the hood, this is done with [Draccus](https://github.com/dlwh/draccus) which is a tool dedicated to this purpose. If you're familiar with Hydra, Draccus can similarly load configurations from config files (.json, .yaml) and also override their values through command line inputs. Unlike Hydra, these configurations are pre-defined in the code through dataclasses rather than being defined entirely in config files. This allows for more rigorous serialization/deserialization, typing, and to manipulate configuration as objects directly in the code and not as dictionaries or namespaces (which enables nice features in an IDE such as autocomplete, jump-to-def, etc.) @@ -50,9 +50,9 @@ By default, every field takes its default value specified in the dataclass. If a ## Specifying values from the CLI -Let's say that we want to train [Diffusion Policy](../../lerobot/common/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: +Let's say that we want to train [Diffusion Policy](../src/lerobot/policies/diffusion) on the [pusht](https://huggingface.co/datasets/lerobot/pusht) dataset, using the [gym_pusht](https://github.com/huggingface/gym-pusht) environment for evaluation. The command to do so would look like this: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --dataset.repo_id=lerobot/pusht \ --policy.type=diffusion \ --env.type=pusht @@ -60,12 +60,12 @@ python lerobot/scripts/train.py \ Let's break this down: - To specify the dataset, we just need to specify its `repo_id` on the hub which is the only required argument in the `DatasetConfig`. The rest of the fields have default values and in this case we are fine with those so we can just add the option `--dataset.repo_id=lerobot/pusht`. -- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/common/policies](../../lerobot/common/policies) -- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/common/envs/configs.py`](../../lerobot/common/envs/configs.py) +- To specify the policy, we can just select diffusion policy using `--policy` appended with `.type`. Here, `.type` is a special argument which allows us to select config classes inheriting from `draccus.ChoiceRegistry` and that have been decorated with the `register_subclass()` method. To have a better explanation of this feature, have a look at this [Draccus demo](https://github.com/dlwh/draccus?tab=readme-ov-file#more-flexible-configuration-with-choice-types). In our code, we use this mechanism mainly to select policies, environments, robots, and some other components like optimizers. The policies available to select are located in [lerobot/policies](../src/lerobot/policies) +- Similarly, we select the environment with `--env.type=pusht`. The different environment configs are available in [`lerobot/envs/configs.py`](../src/lerobot/envs/configs.py) -Let's see another example. Let's say you've been training [ACT](../../lerobot/common/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: +Let's see another example. Let's say you've been training [ACT](../src/lerobot/policies/act) on [lerobot/aloha_sim_insertion_human](https://huggingface.co/datasets/lerobot/aloha_sim_insertion_human) using the [gym-aloha](https://github.com/huggingface/gym-aloha) environment for evaluation with: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ @@ -74,9 +74,9 @@ python lerobot/scripts/train.py \ > Notice we added `--output_dir` to explicitly tell where to write outputs from this run (checkpoints, training state, configs etc.). This is not mandatory and if you don't specify it, a default directory will be created from the current date and time, env.type and policy.type. This will typically look like `outputs/train/2025-01-24/16-10-05_aloha_act`. We now want to train a different policy for aloha on another task. We'll change the dataset and use [lerobot/aloha_sim_transfer_cube_human](https://huggingface.co/datasets/lerobot/aloha_sim_transfer_cube_human) instead. Of course, we also need to change the task of the environment as well to match this other task. -Looking at the [`AlohaEnv`](../../lerobot/common/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: +Looking at the [`AlohaEnv`](../src/lerobot/envs/configs.py) config, the task is `"AlohaInsertion-v0"` by default, which corresponds to the task we trained on in the command above. The [gym-aloha](https://github.com/huggingface/gym-aloha?tab=readme-ov-file#description) environment also has the `AlohaTransferCube-v0` task which corresponds to this other task we want to train on. Putting this together, we can train this new policy on this different task using: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ --env.type=aloha \ @@ -111,7 +111,7 @@ Now, let's assume that we want to reproduce the run just above. That run has pro We can then simply load the config values from this file using: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 ``` @@ -119,7 +119,7 @@ python lerobot/scripts/train.py \ Similarly to Hydra, we can still override some parameters in the CLI if we want to, e.g.: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --config_path=outputs/train/act_aloha_transfer/checkpoints/last/pretrained_model/ \ --output_dir=outputs/train/act_aloha_transfer_2 --policy.n_action_steps=80 @@ -128,7 +128,7 @@ python lerobot/scripts/train.py \ `--config_path` can also accept the repo_id of a repo on the hub that contains a `train_config.json` file, e.g. running: ```bash -python lerobot/scripts/train.py --config_path=lerobot/diffusion_pusht +python -m lerobot.scripts.train --config_path=lerobot/diffusion_pusht ``` will start a training run with the same configuration used for training [lerobot/diffusion_pusht](https://huggingface.co/lerobot/diffusion_pusht) @@ -139,7 +139,7 @@ Being able to resume a training run is important in case it crashed or aborted f Let's reuse the command from the previous run and add a few more options: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --policy.type=act \ --dataset.repo_id=lerobot/aloha_sim_transfer_cube_human \ --env.type=aloha \ @@ -155,7 +155,7 @@ INFO 2025-01-24 16:10:56 ts/train.py:263 Checkpoint policy after step 100 ``` Now let's simulate a crash by killing the process (hit `ctrl`+`c`). We can then simply resume this run from the last checkpoint available with: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true ``` @@ -164,7 +164,7 @@ You should see from the logging that your training picks up from where it left o Another reason for which you might want to resume a run is simply to extend training and add more training steps. The number of training steps is set by the option `--steps`, which is 100 000 by default. You could double the number of steps of the previous run with: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --config_path=outputs/train/run_resumption/checkpoints/last/pretrained_model/ \ --resume=true \ --steps=200000 @@ -195,7 +195,7 @@ In addition to the features currently in Draccus, we've added a special `.path` For example, we could fine-tune a [policy pre-trained on the aloha transfer task](https://huggingface.co/lerobot/act_aloha_sim_transfer_cube_human) on the aloha insertion task. We can achieve this with: ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ @@ -236,7 +236,7 @@ We'll summarize here the main use cases to remember from this tutorial. #### Train a policy from scratch – CLI ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --policy.type=act \ # <- select 'act' policy --env.type=pusht \ # <- select 'pusht' environment --dataset.repo_id=lerobot/pusht # <- train on this dataset @@ -244,14 +244,14 @@ python lerobot/scripts/train.py \ #### Train a policy from scratch - config file + CLI ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --config_path=path/to/pretrained_model \ # <- can also be a repo_id --policy.n_action_steps=80 # <- you may still override values ``` #### Resume/continue a training run ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --config_path=checkpoint/pretrained_model/ \ --resume=true \ --steps=200000 # <- you can change some training parameters @@ -259,7 +259,7 @@ python lerobot/scripts/train.py \ #### Fine-tuning ```bash -python lerobot/scripts/train.py \ +python -m lerobot.scripts.train \ --policy.path=lerobot/act_aloha_sim_transfer_cube_human \ # <- can also be a local path to a checkpoint --dataset.repo_id=lerobot/aloha_sim_insertion_human \ --env.type=aloha \ diff --git a/examples/lekiwi/record.py b/examples/lekiwi/record.py index e6b774f19..68d6d1b01 100644 --- a/examples/lekiwi/record.py +++ b/examples/lekiwi/record.py @@ -51,8 +51,7 @@ while i < NB_CYCLES_CLIENT_CONNECTION: action_sent = robot.send_action(action) observation = robot.get_observation() - task = "Dummy Example Task Dataset" - frame = {**action_sent, **observation, "task": task} + frame = {**action_sent, **observation, "task": "Dummy Example Task Dataset"} dataset.add_frame(frame) i += 1 diff --git a/lerobot/common/policies/pi0fast/configuration_pi0fast.py b/lerobot/common/policies/pi0fast/configuration_pi0fast.py deleted file mode 100644 index 29c856e06..000000000 --- a/lerobot/common/policies/pi0fast/configuration_pi0fast.py +++ /dev/null @@ -1,136 +0,0 @@ -from dataclasses import dataclass, field - -from lerobot.common.optim.optimizers import AdamWConfig -from lerobot.common.optim.schedulers import ( - CosineDecayWithWarmupSchedulerConfig, -) -from lerobot.configs.policies import PreTrainedConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature - - -@PreTrainedConfig.register_subclass("pi0fast") -@dataclass -class PI0FASTConfig(PreTrainedConfig): - # Input / output structure. - n_obs_steps: int = 1 - chunk_size: int = 10 - n_action_steps: int = 5 - - normalization_mapping: dict[str, NormalizationMode] = field( - default_factory=lambda: { - "VISUAL": NormalizationMode.IDENTITY, - "STATE": NormalizationMode.MEAN_STD, - "ACTION": NormalizationMode.MEAN_STD, - } - ) - - # Shorter state and action vectors will be padded - max_state_dim: int = 32 # 32 - max_action_dim: int = 32 # 32 - - # Image preprocessing - resize_imgs_with_padding: tuple[int, int] = (224, 224) - interpolate_like_pi: bool = False - - # Add empty images. Used by pi0_aloha_sim which adds the empty - # left and right wrist cameras in addition to the top camera. - empty_cameras: int = 0 - - # Converts the joint and gripper values from the standard Aloha space to - # the space used by the pi internal runtime which was used to train the base model. - adapt_to_pi_aloha: bool = False - - # Converts joint dimensions to deltas with respect to the current state before passing to the model. - # Gripper dimensions will remain in absolute values. - use_delta_joint_actions_aloha: bool = False - - # Tokenizer - tokenizer_max_length: int = 48 - - # Projector - proj_width: int = 1024 - - # Decoding - max_decoding_steps: int = 256 - fast_skip_tokens: int = 128 # Skip last 128 tokens in PaliGemma vocab since they are special tokens - max_input_seq_len: int = 256 # 512 - - # Utils - use_cache: bool = True - - # Frozen parameters - freeze_vision_encoder: bool = True - freeze_lm_head: bool = True - - # Training presets - optimizer_lr: float = 1e-4 - optimizer_betas: tuple[float, float] = (0.9, 0.95) - optimizer_eps: float = 1e-8 - optimizer_weight_decay: float = 1e-5 - - scheduler_warmup_steps: int = 1_000 - scheduler_decay_steps: int = 30_000 - scheduler_decay_lr: float = 2.5e-6 - - checkpoint_path: str = None - - padding_side: str = "right" - - precision: str = "bfloat16" - grad_clip_norm: float = 1 - - # Allows padding/truncation of generated action tokens during detokenization to ensure decoding. - # In the original version, tensors of 0s were generated if shapes didn't match for stable decoding. - relaxed_action_decoding: bool = True - - def __post_init__(self): - super().__post_init__() - - """Input validation (not exhaustive).""" - if self.n_action_steps > self.chunk_size: - raise ValueError( - f"The chunk size is the upper bound for the number of action steps per model invocation. Got " - f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`." - ) - if self.n_obs_steps != 1: - raise ValueError( - f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`" - ) - - def validate_features(self) -> None: - for i in range(self.empty_cameras): - key = f"observation.images.empty_camera_{i}" - empty_camera = PolicyFeature( - type=FeatureType.VISUAL, - shape=(3, 480, 640), - ) - self.input_features[key] = empty_camera - - def get_optimizer_preset(self) -> AdamWConfig: - return AdamWConfig( - lr=self.optimizer_lr, - betas=self.optimizer_betas, - eps=self.optimizer_eps, - weight_decay=self.optimizer_weight_decay, - grad_clip_norm=self.grad_clip_norm, - ) - - def get_scheduler_preset(self): - return CosineDecayWithWarmupSchedulerConfig( - peak_lr=self.optimizer_lr, - decay_lr=self.scheduler_decay_lr, - num_warmup_steps=self.scheduler_warmup_steps, - num_decay_steps=self.scheduler_decay_steps, - ) - - @property - def observation_delta_indices(self) -> None: - return None - - @property - def action_delta_indices(self) -> list: - return list(range(self.chunk_size)) - - @property - def reward_delta_indices(self) -> None: - return None diff --git a/lerobot/common/policies/pi0fast/modeling_pi0fast.py b/lerobot/common/policies/pi0fast/modeling_pi0fast.py deleted file mode 100644 index dbf5266b1..000000000 --- a/lerobot/common/policies/pi0fast/modeling_pi0fast.py +++ /dev/null @@ -1,982 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -π0+FAST: Efficient Action Tokenization for Vision-Language-Action Models - -[Paper](https://huggingface.co/papers/2501.09747) -[Jax code](https://github.com/Physical-Intelligence/openpi) - -Designed by Physical Intelligence. Ported from Jax by Hugging Face. - -Example of finetuning the pi0+FAST pretrained model (`pi0_fast_base` in `openpi`): -```bash -python lerobot/scripts/train.py \ ---policy.path=lerobot/pi0fast_base \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of training the pi0+FAST neural network with from scratch: -```bash -python lerobot/scripts/train.py \ ---policy.type=pi0fast \ ---dataset.repo_id=danaaubakirova/koch_test -``` - -Example of using the pi0 pretrained model outside LeRobot training framework: -```python -policy = PI0FASTPolicy.from_pretrained("lerobot/pi0fast_base") -``` - -""" - -from collections import deque -from functools import partial - -import numpy as np -import torch -import torch.nn.functional as F # noqa: N812 -from PIL import Image -from scipy.fft import idct -from torch import Tensor, nn -from transformers import AutoProcessor, AutoTokenizer, PaliGemmaForConditionalGeneration -from transformers.cache_utils import HybridCache, StaticCache -from transformers.models.auto import CONFIG_MAPPING - -from lerobot.common.constants import ACTION, OBS_STATE -from lerobot.common.policies.normalize import Normalize, Unnormalize -from lerobot.common.policies.pi0fast.configuration_pi0fast import PI0FASTConfig -from lerobot.common.policies.pretrained import PreTrainedPolicy - -PRECISION = { - "float16": torch.float16, - "float32": torch.float32, - "bfloat16": torch.bfloat16, -} - - -def normalize(x, min_val, max_val): - return (x - min_val) / (max_val - min_val) - - -def unnormalize(x, min_val, max_val): - return x * (max_val - min_val) + min_val - - -def safe_arcsin(value): - # This ensures that the input stays within - # [−1,1] to avoid invalid values for arcsin - return torch.arcsin(torch.clamp(value, -1.0, 1.0)) - - -def aloha_gripper_to_angular(value): - # Aloha transforms the gripper positions into a linear space. The following code - # reverses this transformation to be consistent with pi0 which is pretrained in - # angular space. - # - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED - value = unnormalize(value, min_val=0.01844, max_val=0.05800) - - # This is the inverse of the angular to linear transformation inside the Interbotix code. - def linear_to_radian(linear_position, arm_length, horn_radius): - value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position) - return safe_arcsin(value) - - # The constants are taken from the Interbotix code. - value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022) - - # Normalize to [0, 1]. - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - return normalize(value, min_val=0.4, max_val=1.5) - - -def aloha_gripper_from_angular(value): - # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha. - # Note that the units are still angular but the range is different. - - # The values 0.4 and 1.5 were measured on an actual Trossen robot. - value = unnormalize(value, min_val=0.4, max_val=1.5) - - # These values are coming from the Aloha code: - # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE - return normalize(value, min_val=-0.6213, max_val=1.4910) - - -def aloha_gripper_from_angular_inv(value): - # Directly inverts the gripper_from_angular function. - value = unnormalize(value, min_val=-0.6213, max_val=1.4910) - return normalize(value, min_val=0.4, max_val=1.5) - - -class PI0FASTPolicy(PreTrainedPolicy): - """Wrapper class around PI0FAST tokenizer and model to train and run inference within LeRobot.""" - - config_class = PI0FASTConfig - name = "pi0fast" - - def __init__( - self, - config: PI0FASTConfig, - dataset_stats: dict[str, dict[str, Tensor]] | None = None, - ): - """ - Args: - config: Policy configuration class instance or None, in which case the default instantiation of - the configuration class is used. - dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected - that they will be passed with a call to `load_state_dict` before the policy is used. - """ - - super().__init__(config) - config.validate_features() - self.config = config - - self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) - self.normalize_targets = Normalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - self.unnormalize_outputs = Unnormalize( - config.output_features, config.normalization_mapping, dataset_stats - ) - - self.language_tokenizer = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224") - self.model = PI0FAST(config) - - self.reset() - - def reset(self): - """This should be called whenever the environment is reset.""" - self._action_queue = deque([], maxlen=self.config.n_action_steps) - - def get_optim_params(self) -> dict: - return self.parameters() - - def _pi_aloha_decode_state(self, state): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - state[:, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx]) - return state - - def _pi_aloha_encode_actions(self, actions): - # Flip the joints. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx]) - return actions - - def _pi_aloha_encode_actions_inv(self, actions): - # Flip the joints again. - for motor_idx in [1, 2, 8, 9]: - actions[:, :, motor_idx] *= -1 - # Reverse the gripper transformation that is being applied by the Aloha runtime. - for motor_idx in [6, 13]: - actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx]) - return actions - - @torch.no_grad - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: - """Predict a chunk of actions given environment observations.""" - raise NotImplementedError("Currently not implemented for PI0FAST") - - @torch.no_grad - def select_action(self, batch: dict[str, Tensor]) -> Tensor: - """Select a single action given environment observations. - - This method wraps `select_actions` in order to return one action at a time for execution in the - environment. It works by managing the actions in a queue and only calling `select_actions` when the - queue is empty. - """ - self.eval() - - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - - batch = self.normalize_inputs(batch) - - # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by - # querying the policy. - if len(self._action_queue) == 0: - actions = self.model.generate_actions(batch) - - actions = actions[:, : self.config.n_action_steps] - - original_action_dim = self.config.action_feature.shape[ - 0 - ] # self.config.max_action_dim # self.config.action_feature.shape[0] - actions = actions[:, :, :original_action_dim] - - actions = self.unnormalize_outputs({"action": actions})["action"] - - if self.config.adapt_to_pi_aloha: - actions = self._pi_aloha_encode_actions(actions) - - # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue - # effectively has shape (n_action_steps, batch_size, *), hence the transpose. - self._action_queue.extend(actions.transpose(0, 1)) - return self._action_queue.popleft() - - def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: - if self.config.adapt_to_pi_aloha: - batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE]) - batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION]) - batch = self.normalize_inputs(batch) - batch = self.normalize_targets(batch) - loss_dict = self.model.forward(batch) - return loss_dict["loss"], loss_dict - - -def block_causal_update_causal_mask( - attention_mask, - token_type_ids=None, - past_key_values=None, - cache_position=None, - input_tensor=None, - attn_implementation: str = "eager", - dtype: torch.dtype = "float32", -): - """ - Update the causal mask during training and generation. It can be customized to different attention masks. - """ - if attn_implementation == "flash_attention_2": - if attention_mask is not None and 0.0 in attention_mask: - return attention_mask - return None - using_static_cache = isinstance(past_key_values, StaticCache) - min_dtype = torch.finfo(dtype).min - - if input_tensor is None: - input_tensor = attention_mask - - inputs_lead_dim, sequence_length = input_tensor.shape[:2] - - if using_static_cache or isinstance(past_key_values, HybridCache): - target_length = past_key_values.get_max_cache_shape() - else: - target_length = ( - attention_mask.shape[-1] - if isinstance(attention_mask, torch.Tensor) - else cache_position[0] + sequence_length + 1 - ) - - # Handle precomputed attention masks - if attention_mask is not None and attention_mask.dim() == 4: - return attention_mask - - # Causal mask initialization - causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device - ) - - # Standard causal masking (triu ensures tokens can only attend to past) - if sequence_length != 1: - causal_mask = torch.triu(causal_mask, diagonal=1) - - # Apply block causal mask - if token_type_ids is not None: - token_type_ids = token_type_ids.to(causal_mask.device).bool() - cumsum = torch.cumsum(token_type_ids, dim=1) - block_causal_mask = cumsum[:, None, :] <= cumsum[:, :, None] - - # Combine causal_mask with block-wise attention mask - causal_mask = torch.where(block_causal_mask, 0.0, causal_mask) - causal_mask = causal_mask[:, None, :, :] - else: - # Apply past cache position constraint - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - else: - # Apply past cache position constraint - causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape( - -1, 1 - ) - causal_mask = causal_mask[None, None, :, :].expand(inputs_lead_dim, 1, -1, -1) - - if attention_mask is not None: - causal_mask = causal_mask.clone() # Copy to contiguous memory for in-place edits - mask_length = attention_mask.shape[-1] - - # Apply padding mask - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( - causal_mask.device - ) - padding_mask = padding_mask == 0 - causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( - padding_mask, min_dtype - ) - - return causal_mask - - -def prepare_inputs_for_generation( - # self, - input_ids, - past_key_values=None, - inputs_embeds=None, - cache_position=None, - position_ids=None, - pixel_values=None, - attention_mask=None, - token_type_ids=None, - use_cache=True, - num_logits_to_keep=None, - labels=None, - self=None, - **kwargs, -): - # create block causal attention - if cache_position[0] > 0 and input_ids.shape[1] > 0: - input_tensor = input_ids[:, -1:] - new_positions = ( - torch.ones( - (position_ids.shape[0], input_ids.shape[1]), - dtype=position_ids.dtype, - device=position_ids.device, - ).cumsum(-1) - + position_ids[:, -1:] - ) - position_ids = torch.cat([position_ids, new_positions], dim=-1) - else: - input_tensor = inputs_embeds - attention_mask = block_causal_update_causal_mask( - attention_mask=attention_mask, - past_key_values=past_key_values, - cache_position=cache_position, - input_tensor=input_tensor, - token_type_ids=token_type_ids, - dtype=self.dtype, - attn_implementation=self.config.text_config._attn_implementation, - ) - # Overwritten -- custom `position_ids` and `pixel_values` handling - model_inputs = self.language_model.prepare_inputs_for_generation( - input_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - position_ids=position_ids, - cache_position=cache_position, - use_cache=use_cache, - num_logits_to_keep=num_logits_to_keep, - token_type_ids=token_type_ids, - **kwargs, - ) - - # Position_ids in Paligemma are 1-indexed - if model_inputs.get("position_ids") is not None: - model_inputs["position_ids"] += 1 - # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore - # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always - if cache_position[0] == 0: - model_inputs["pixel_values"] = pixel_values - is_training = token_type_ids is not None and labels is not None - if cache_position[0] == 0 and isinstance(past_key_values, HybridCache): - input_tensor = inputs_embeds if inputs_embeds is not None else input_ids - causal_mask = self._update_causal_mask( - attention_mask, token_type_ids, past_key_values, cache_position, input_tensor, is_training - ) - model_inputs["attention_mask"] = causal_mask - - return model_inputs - - -class PI0FAST(nn.Module): - def __init__(self, config: PI0FASTConfig): - super().__init__() - self.config = config - - # TODO: move tokenizers in Policy - fast_tokenizer_path = "physical-intelligence/fast" - pi0_paligemma_path = "google/paligemma-3b-pt-224" - self.paligemma_tokenizer = AutoTokenizer.from_pretrained(pi0_paligemma_path) - self.processor = AutoProcessor.from_pretrained(pi0_paligemma_path) - self.fast_tokenizer = AutoProcessor.from_pretrained(fast_tokenizer_path, trust_remote_code=True) - self.fast_skip_tokens = self.config.fast_skip_tokens - self.max_input_seq_len = self.config.max_input_seq_len - self.action_horizon = self.config.chunk_size - self.action_dim = self.config.action_feature.shape[ - 0 - ] # self.config.max_action_dim # self.config.action_feature.shape[0] - precision = config.precision - torch_precision = PRECISION.get(precision, torch.float32) - self.pad_token_id = ( - self.paligemma_tokenizer.pad_token_id - if hasattr(self.paligemma_tokenizer, "pad_token_id") - else self.paligemma_tokenizer.eos_token_id - ) - - paligemma_config = CONFIG_MAPPING["paligemma"]( - transformers_version="4.48.1", - _vocab_size=257152, - bos_token_id=2, - eos_token_id=1, - hidden_size=2048, - image_token_index=257152, - model_type="paligemma", - pad_token_id=0, - projection_dim=2048, - text_config={ - "hidden_activation": "gelu_pytorch_tanh", - "hidden_size": 2048, - "intermediate_size": 16384, - "model_type": "gemma", - "num_attention_heads": 8, - "num_hidden_layers": 18, - "num_image_tokens": 256, - "num_key_value_heads": 1, - "torch_dtype": precision, - "vocab_size": 257152, - "_attn_implementation": "eager", - }, - vision_config={ - "hidden_size": 1152, - "intermediate_size": 4304, - "model_type": "siglip_vision_model", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "num_image_tokens": 256, - "patch_size": 14, - "projection_dim": 2048, - "projector_hidden_act": "gelu_pytorch_tanh", - "torch_dtype": precision, - "vision_use_head": False, - }, - ) - self.pi0_paligemma = PaliGemmaForConditionalGeneration(config=paligemma_config) - - self.pi0_paligemma.prepare_inputs_for_generation = partial( - prepare_inputs_for_generation, self=self.pi0_paligemma - ) - # change important stuff in bf16 - params_to_change_dtype = [ - "language_model", - "vision_tower", - "multi_modal", - ] - for name, param in self.pi0_paligemma.named_parameters(): - if any(selector in name for selector in params_to_change_dtype): - param.data = param.data.to(dtype=torch_precision) - self.set_requires_grad() - self.image_keys = self.config.image_features.keys() - self.ignore_index = self.pi0_paligemma.config.ignore_index - self.padding_side = self.config.padding_side - - def set_requires_grad(self): - if self.config.freeze_vision_encoder: - self.pi0_paligemma.vision_tower.eval() - for params in self.pi0_paligemma.vision_tower.parameters(): - params.requires_grad = False - # To avoid unused params issue with distributed training - if self.config.freeze_lm_head: - for name, params in self.pi0_paligemma.named_parameters(): - if "embed_tokens" in name: # lm heads and embedding layer are tied - params.requires_grad = False - - def embed_tokens(self, tokens: torch.Tensor): - return self.pi0_paligemma.language_model.model.embed_tokens(tokens) - - def prepare_inputs_for_generation(self, *args, **kwargs): - return self.pi0_paligemma.prepare_inputs_for_generation(*args, **kwargs) - - def prepare_images(self, batch): - """Preprocess LeRobot batch into Pi0 inputs""" - images = [] - img_masks = [] - present_img_keys = [key for key in self.image_keys if key in batch] - if len(present_img_keys) == 0: - raise ValueError( - f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})" - ) - - # Preprocess image features present in the batch - num_empty_cameras = 0 - for key in self.image_keys: - if key in present_img_keys: - img = batch[key] - - if self.config.resize_imgs_with_padding is not None: - img = resize_with_pad( - img, - *self.config.resize_imgs_with_padding, - pad_value=0, - interpolate_like_pi=self.config.interpolate_like_pi, - ) - - # Normalize from range [0,1] to [-1,1] as expected by siglip - img = img * 2.0 - 1.0 - - bsize = img.shape[0] - device = img.device - mask = torch.ones(bsize, dtype=torch.bool, device=device) - else: - if num_empty_cameras >= self.config.empty_cameras: - continue - img = torch.ones_like(img) * -1 - bsize = img.shape[0] - device = img.device - mask = torch.ones(bsize, dtype=torch.bool, device=device) - num_empty_cameras += 1 - - images.append(img) - img_masks.append(mask) - return images, img_masks - - def normalize_actions(self, actions: torch.Tensor) -> torch.Tensor: - mins = actions.amin(dim=(1, 2), keepdim=True) # [0] - maxs = actions.amax(dim=(1, 2), keepdim=True) # [0] - return 2 * (actions - mins) / (maxs - mins + 1e-8) - 1 - - def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor: - out = self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens - return out - - def fast_tokenizer_wrapper(self, actions_norm): - """ - A wrapper for self.fast_tokenizer that ensures batch processing, - conversion to PyTorch tensors, and returns a dictionary without padding. - """ - batch_tokens = self.fast_tokenizer(actions_norm) - fast_out = self.processor.tokenizer.pad({"input_ids": batch_tokens}, return_tensors="pt") - - return fast_out - - def create_token_type_ids(self, padded_mask: torch.Tensor, prefix_len: int) -> torch.Tensor: - token_type_ids = torch.zeros_like(padded_mask, dtype=torch.bool) - # Compute cumulative sum mask - cumsum_mask = (padded_mask != 0).cumsum(dim=1) - # Suffix block (everything after prefix_len) - suffix_mask = cumsum_mask > prefix_len - token_type_ids = suffix_mask - return token_type_ids - - def create_input_tokens(self, state, lang_text, actions=None): - bsize = state.shape[0] - device = state.device - bins = torch.linspace(-1, 1, 256 + 1, device=device)[:-1] - discretized = torch.bucketize(state, bins) - 1 - discretized = discretized[:, :32] - - prefix_texts = [] - state_text = [] - for txt, disc in zip(lang_text, discretized, strict=False): - cleaned = txt.lower().strip().replace("_", " ") - state_str = " ".join(str(val.item()) for val in disc) - prefix_texts.append(f"Task: {cleaned}, State: {state_str};\n") - state_text.append(f"State: {state_str};\n") - - prefix_out = self.paligemma_tokenizer( - prefix_texts, add_special_tokens=True, return_tensors="pt", padding="longest", truncation=False - ) - prefix_ids = prefix_out["input_ids"].to(device) - prefix_mask = prefix_out["attention_mask"].to(device) - prefix_lens = prefix_mask.sum(dim=1)[:, None].cpu() - - if actions is not None: - actions_norm = self.normalize_actions(actions) - actions_pad = F.pad( - actions_norm, (0, max(0, self.config.max_action_dim - actions_norm.shape[2])), value=0 - )[:, :, : self.config.max_action_dim] - fast_out = self.fast_tokenizer_wrapper( - actions_pad.cpu(), - ) - act_ids = fast_out["input_ids"] - act_mask = fast_out["attention_mask"].to(device) - - act_ids = self._act_tokens_to_paligemma_tokens(act_ids).to(device) - # Replace action with 0 to pad tokens - act_ids = torch.where( - act_ids == self.paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens, - self.pad_token_id, - act_ids, - ) - - eos_token = torch.tensor( - [self.paligemma_tokenizer.eos_token_id], dtype=torch.long, device=device - ).expand(bsize, -1) - eos_mask = torch.tensor([1], dtype=torch.long, device=device).expand(bsize, -1) - bos = self.paligemma_tokenizer("Action: ", add_special_tokens=False, return_tensors="pt") - bos_token = bos["input_ids"].expand(act_ids.shape[0], -1).to(device) - bos_mask = bos["attention_mask"].expand(act_ids.shape[0], -1).to(device) - act_ids = torch.cat([bos_token, act_ids, eos_token], dim=1) - act_mask = torch.cat([bos_mask, act_mask, eos_mask], dim=1) - act_mask = act_mask.to(device) - else: - act_ids = torch.empty(bsize, self.pad_token_id, dtype=torch.long, device=device) - act_mask = torch.empty(bsize, 0, dtype=torch.long, device=device) - final_ids = torch.cat([prefix_ids, act_ids], dim=1) - - final_mask = torch.cat([prefix_mask, act_mask], dim=1) - batch_inputs = {"input_ids": final_ids.tolist(), "attention_mask": final_mask.tolist()} - - # Use tokenizer pad function - padded_output = self.paligemma_tokenizer.pad( - batch_inputs, padding="longest", max_length=180, return_tensors="pt" - ) - padded_mask = padded_output["attention_mask"] - - # define tensor of padding lengths - att_mask = (padded_mask != 0).cumsum(dim=1) > prefix_lens - - token_type_ids = self.create_token_type_ids(padded_mask=padded_mask, prefix_len=prefix_lens) - - padded_output["padded_mask"] = padded_output.pop("attention_mask") - padded_output["attention_mask"] = att_mask - # loss is computed not on prefix, and not on padding - padded_output["loss_mask"] = att_mask & padded_output["padded_mask"] - padded_output["token_type_ids"] = token_type_ids - return padded_output - - def shift_padding_side( - self, - tokens: torch.Tensor, - ar_mask: torch.Tensor, - padding_mask: torch.Tensor, - loss_mask: torch.Tensor, - targets: torch.Tensor, - token_type_ids: torch.Tensor, - padding_side: str = "right", - ) -> tuple[torch.Tensor]: - if padding_side not in ["right", "left"]: - return tokens, ar_mask, padding_mask, loss_mask, targets, token_type_ids - - new_tokens = torch.empty_like(tokens) - new_ar_masks = torch.empty_like(ar_mask) - new_padding_mask = torch.empty_like(padding_mask) - new_loss_mask = torch.empty_like(loss_mask) - new_targets = torch.empty_like(targets) - new_token_type_ids = torch.empty_like(token_type_ids) - batch_size = tokens.shape[0] - for i in range(batch_size): - padding_indices = torch.where(padding_mask[i] == 0)[0] - non_padding_indices = torch.where(padding_mask[i] == 1)[0] - if padding_side == "left": - new_indices = torch.cat((padding_indices, non_padding_indices), dim=0) - else: - new_indices = torch.cat((non_padding_indices, padding_indices), dim=0) - new_tokens[i] = tokens[i].index_select(0, new_indices) - new_ar_masks[i] = ar_mask[i].index_select(0, new_indices) - new_padding_mask[i] = padding_mask[i].index_select(0, new_indices) - new_loss_mask[i] = loss_mask[i].index_select(0, new_indices) - new_targets[i] = targets[i].index_select(0, new_indices) - new_token_type_ids[i] = token_type_ids[i].index_select(0, new_indices) - - return new_tokens, new_ar_masks, new_padding_mask, new_loss_mask, new_targets, new_token_type_ids - - def forward(self, batch: dict[str, Tensor]): - device = batch[OBS_STATE].device - # TODO: keep like this or move to the policy .forward - images, img_masks = self.prepare_images(batch) - - padded_outs = self.create_input_tokens( - state=batch[OBS_STATE], - lang_text=batch["task"], - actions=batch[ACTION], - ) - - embs, pad_masks, _, targets, loss_mask, token_type_ids = self.embed_inputs( - images, - img_masks, - padded_outs["input_ids"], - padded_outs["padded_mask"], - padded_outs["attention_mask"], - padded_outs["loss_mask"], - padded_outs["token_type_ids"], - padding_side=self.padding_side, - ) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - token_type_ids = token_type_ids.to(dtype=torch.int64) - past_seen_tokens = 0 - cache_position = torch.arange(past_seen_tokens, past_seen_tokens + embs.shape[1], device=embs.device) - pad_masks = block_causal_update_causal_mask( - attention_mask=pad_masks, - past_key_values=None, - cache_position=cache_position, - input_tensor=embs, - token_type_ids=token_type_ids, - dtype=self.pi0_paligemma.dtype, - attn_implementation=self.pi0_paligemma.config.text_config._attn_implementation, - ) - outputs = self.pi0_paligemma.forward( - input_ids=None, - token_type_ids=None, - attention_mask=pad_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=embs, - use_cache=False, - labels=None, - ) - - logits = outputs.logits - - loss_fct = nn.CrossEntropyLoss(reduction="none") - - # Shift left for next-step prediction - logits = logits[:, :-1, :] - targets = targets[:, 1:].to(device) # Shift targets - loss_mask = loss_mask[:, 1:].to(device) # Ensure correct shape - - # Compute per-token loss - token_loss = loss_fct(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1)) - - # Apply loss mask - token_loss = token_loss * loss_mask.reshape(-1) - - # Compute final loss - loss = token_loss.sum() / torch.clamp(loss_mask.sum(), min=1) - - # Return loss dictionary - loss_dict = {"ce_loss": loss.item(), "loss": loss} - return loss_dict - - def decode_actions_with_fast( - self, - tokens: list[list[int]], - *, - time_horizon: int | None = None, - action_dim: int | None = None, - relaxed_decoding: bool = True, - ) -> np.array: - """ - Adapt original decoding in FAST to always return actions instead of zeros. - """ - self.time_horizon = ( - time_horizon or self.fast_tokenizer.time_horizon or self.fast_tokenizer.called_time_horizon - ) - self.action_dim = ( - action_dim or self.fast_tokenizer.action_dim or self.fast_tokenizer.called_action_dim - ) - - # Cache the time horizon and action dimension for the next call - self.called_time_horizon = self.time_horizon - self.called_action_dim = self.action_dim - - assert self.time_horizon is not None and self.action_dim is not None, ( - "Tokenizer not initialized, call encode() once or pass in time_horizon and action_dim." - ) - - decoded_actions = [] - for token in tokens: - try: - decoded_tokens = self.fast_tokenizer.bpe_tokenizer.decode(token) - decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.fast_tokenizer.min_token - if relaxed_decoding: - # Expected sequence length - expected_seq_len = self.time_horizon * self.action_dim - diff = expected_seq_len - decoded_dct_coeff.shape[0] - # Apply truncation if too long - if diff < 0: - decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # Truncate on the right - # Apply padding if too short - elif diff > 0: - decoded_dct_coeff = np.pad( - decoded_dct_coeff, (0, diff), mode="constant", constant_values=0 - ) - - decoded_dct_coeff = decoded_dct_coeff.reshape(-1, self.action_dim) - assert decoded_dct_coeff.shape == ( - self.time_horizon, - self.action_dim, - ), ( - f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({self.time_horizon}, {self.action_dim})" - ) - except Exception as e: - print(f"Error decoding tokens: {e}") - print(f"Tokens: {token}") - decoded_dct_coeff = np.zeros((self.time_horizon, self.action_dim)) - decoded_actions.append(idct(decoded_dct_coeff / self.fast_tokenizer.scale, axis=0, norm="ortho")) - return np.stack(decoded_actions) - - def extract_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor: - """ - Extracts actions from predicted output tokens using the FAST model. - - Args: - tokens (torch.Tensor): The input tensor of tokenized outputs. - action_horizon (int): The number of timesteps for actions. - action_dim (int): The dimensionality of each action. - - Returns: - torch.Tensor: The extracted actions as a tensor of shape (action_horizon, action_dim). - """ - # Decode predicted output tokens - decoded_tokens = self.paligemma_tokenizer.batch_decode(tokens, skip_special_tokens=True) - cleaned_tokens = [ - tokens_sequence.replace("Action:", "").replace(":", "").strip().split("|")[0].strip() - for tokens_sequence in decoded_tokens - ] - raw_action_tokens = [ - self.processor.tokenizer.encode(sample_tokens, return_tensors="pt", padding=False) - for sample_tokens in cleaned_tokens - ] # something like this should be robust #looks good - action_tokens = [ - self._act_tokens_to_paligemma_tokens(raw_action_token) for raw_action_token in raw_action_tokens - ] - # returns the tensor of decoded actions per sample in a list - decoded_actions = [ - torch.tensor( - self.decode_actions_with_fast( - tok.tolist(), - time_horizon=action_horizon, - action_dim=action_dim, - relaxed_decoding=self.config.relaxed_action_decoding, - ), - device=tokens.device, - ).squeeze(0) - for tok in action_tokens - ] - - return torch.stack( - decoded_actions, - dim=0, - ) - - def generate_actions(self, batch: dict[str, Tensor]): - # TODO: keep like this or move to the policy .forward - images, img_masks = self.prepare_images(batch) - - padded_outs = self.create_input_tokens(state=batch[OBS_STATE], lang_text=batch["task"], actions=None) - embs, pad_masks, att_masks2, targets, loss_mask, token_type_ids = self.embed_inputs( - images, - img_masks, - padded_outs["input_ids"], - padded_outs["padded_mask"], - padded_outs["attention_mask"], - padded_outs["loss_mask"], - padded_outs["token_type_ids"], - padding_side="left", - ) - token_type_ids = token_type_ids.to(dtype=torch.int64) - prefix_position_ids = torch.cumsum(pad_masks, dim=1) - 1 - output_tokens = self.pi0_paligemma.generate( - input_ids=None, - attention_mask=pad_masks, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=embs, - use_cache=self.config.use_cache, - max_new_tokens=self.config.max_decoding_steps, - do_sample=False, - num_beams=1, - token_type_ids=token_type_ids, - ) - actions = self.extract_actions(output_tokens, self.action_horizon, self.action_dim) - return actions - - def embed_image(self, image: torch.Tensor): - # Handle different transformers versions - if hasattr(self.pi0_paligemma, "get_image_features"): - return self.pi0_paligemma.get_image_features(image) - else: - return self.pi0_paligemma.model.get_image_features(image) - - def embed_inputs( - self, - images, - img_masks, - tokens, - pad_mask, - ar_mask, - loss_mask, - token_type_ids, - padding_side: str = "right", - ): - # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty - # images are a list of same size - # vectorizing everything! - device = images[0].device - image_embedding_dim = images[0].shape[-1] # TODO should be from self.config - all_images = torch.stack(images, dim=1).to(device) - b, n, c, h, w = all_images.shape - all_images = all_images.view(b * n, c, h, w) - embedded = self.embed_image(all_images).to(device) - b_n, p, image_embedding_dim = embedded.shape # Extract current dimensions - m = b_n // b # Compute the number of images per sample dynamically - - # Reshape dynamically - embedded = embedded.view(b, m, p, image_embedding_dim) - tokens_embs = self.embed_tokens(tokens.to(device)) - - img_masks = torch.stack(img_masks, dim=1).unsqueeze(-1).to(device) - num_img_emb = embedded.shape[2] - img_pad_masks = img_masks.repeat(1, 1, num_img_emb).view(b, -1) - img_att_masks = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) - - image_target_tokens = ( - torch.ones((b, n, num_img_emb), dtype=torch.long, device=device) * self.pad_token_id - ).reshape(b, -1) - image_loss_mask = torch.zeros((b, n, num_img_emb), dtype=torch.long, device=device).reshape(b, -1) - - embedded = embedded.reshape(b, n * num_img_emb, image_embedding_dim) # Shape: (B, N*P, D) - - embs = torch.cat([embedded, tokens_embs], dim=1).to(device) - pad_masks = torch.cat([img_pad_masks, pad_mask.to(device)], dim=1) - att_masks = torch.cat([img_att_masks, ar_mask.to(device)], dim=1) - loss_masks = torch.cat([image_loss_mask, loss_mask.to(device)], dim=1) - targets = torch.cat([image_target_tokens, tokens.to(device)], dim=1) - token_type_ids = torch.cat([img_att_masks, token_type_ids.to(device)], dim=1) - - # Shift pad tokens to the left (.generate()) or right (.train()) - embs, att_masks, pad_masks, loss_masks, targets, token_type_ids = self.shift_padding_side( - embs, att_masks, pad_masks, loss_masks, targets, token_type_ids, padding_side=padding_side - ) - - targets = torch.where(targets == self.pad_token_id, self.ignore_index, targets) - return embs, pad_masks, att_masks, targets, loss_masks, token_type_ids - - -def resize_with_pad(img, width, height, pad_value=0, interpolate_like_pi=True): - # assume no-op when width height fits already - if img.ndim != 4: - raise ValueError(f"(b,c,h,w) expected, but {img.shape}") - - cur_height, cur_width = img.shape[2:] - - ratio = max(cur_width / width, cur_height / height) - resized_height = int(cur_height / ratio) - resized_width = int(cur_width / ratio) - - if interpolate_like_pi: - img = (img * 255.0).to(dtype=torch.uint8) - img = img.permute(0, 2, 3, 1) - original_device = img.device - img = img.to(device="cpu").numpy() - imgs = [] - for sub_img in img: - sub_img = Image.fromarray(sub_img) - resized_img = sub_img.resize((resized_width, resized_height), resample=2) - resized_img = torch.from_numpy(np.array(resized_img)) - imgs.append(resized_img) - img = torch.stack(imgs, dim=0) - img = img.permute(0, 3, 1, 2) - resized_img = img.to(device=original_device, dtype=torch.float32) / 255.0 - else: - resized_img = F.interpolate( - img, size=(resized_height, resized_width), mode="bilinear", align_corners=False - ) - - pad_height = max(0, int(height - resized_height)) - pad_width = max(0, int(width - resized_width)) - - # pad on left and top of image - padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value) - return padded_img diff --git a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py index 27cd56e6f..ae94c4e02 100644 --- a/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py +++ b/src/lerobot/datasets/v21/convert_dataset_v20_to_v21.py @@ -25,7 +25,7 @@ This script will help you convert any LeRobot dataset already pushed to the hub Usage: ```bash -python lerobot/common/datasets/v21/convert_dataset_v20_to_v21.py \ +python -m lerobot.datasets.v21.convert_dataset_v20_to_v21 \ --repo-id=aliberts/koch_tutorial ``` diff --git a/src/lerobot/find_port.py b/src/lerobot/find_port.py index e69de29bb..cf0282507 100644 --- a/src/lerobot/find_port.py +++ b/src/lerobot/find_port.py @@ -0,0 +1,65 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper to find the USB port associated with your MotorsBus. + +Example: + +```shell +python -m lerobot.find_port +``` +""" + +import platform +import time +from pathlib import Path + + +def find_available_ports(): + from serial.tools import list_ports # Part of pyserial library + + if platform.system() == "Windows": + # List COM ports using pyserial + ports = [port.device for port in list_ports.comports()] + else: # Linux/macOS + # List /dev/tty* ports for Unix-based systems + ports = [str(path) for path in Path("/dev").glob("tty*")] + return ports + + +def find_port(): + print("Finding all available ports for the MotorsBus.") + ports_before = find_available_ports() + print("Ports before disconnecting:", ports_before) + + print("Remove the USB cable from your MotorsBus and press Enter when done.") + input() # Wait for user to disconnect the device + + time.sleep(0.5) # Allow some time for port to be released + ports_after = find_available_ports() + ports_diff = list(set(ports_before) - set(ports_after)) + + if len(ports_diff) == 1: + port = ports_diff[0] + print(f"The port of this MotorsBus is '{port}'") + print("Reconnect the USB cable.") + elif len(ports_diff) == 0: + raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).") + else: + raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") + + +if __name__ == "__main__": + find_port() diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index e69de29bb..ed911e9be 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -0,0 +1,769 @@ +#!/usr/bin/env python + +# Copyright 2024 Tony Z. Zhao and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Action Chunking Transformer Policy + +As per Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware (https://huggingface.co/papers/2304.13705). +The majority of changes here involve removing unused code, unifying naming, and adding helpful comments. +""" + +import math +from collections import deque +from itertools import chain +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +import torchvision +from torch import Tensor, nn +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.misc import FrozenBatchNorm2d + +from lerobot.constants import ACTION, OBS_IMAGES +from lerobot.policies.act.configuration_act import ACTConfig +from lerobot.policies.normalize import Normalize, Unnormalize +from lerobot.policies.pretrained import PreTrainedPolicy + + +class ACTPolicy(PreTrainedPolicy): + """ + Action Chunking Transformer Policy as per Learning Fine-Grained Bimanual Manipulation with Low-Cost + Hardware (paper: https://huggingface.co/papers/2304.13705, code: https://github.com/tonyzhaozh/act) + """ + + config_class = ACTConfig + name = "act" + + def __init__( + self, + config: ACTConfig, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, + ): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.model = ACT(config) + + if config.temporal_ensemble_coeff is not None: + self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) + + self.reset() + + def get_optim_params(self) -> dict: + # TODO(aliberts, rcadene): As of now, lr_backbone == lr + # Should we remove this and just `return self.parameters()`? + return [ + { + "params": [ + p + for n, p in self.named_parameters() + if not n.startswith("model.backbone") and p.requires_grad + ] + }, + { + "params": [ + p + for n, p in self.named_parameters() + if n.startswith("model.backbone") and p.requires_grad + ], + "lr": self.config.optimizer_lr_backbone, + }, + ] + + def reset(self): + """This should be called whenever the environment is reset.""" + if self.config.temporal_ensemble_coeff is not None: + self.temporal_ensembler.reset() + else: + self._action_queue = deque([], maxlen=self.config.n_action_steps) + + @torch.no_grad + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations. + + This method wraps `select_actions` in order to return one action at a time for execution in the + environment. It works by managing the actions in a queue and only calling `select_actions` when the + queue is empty. + """ + self.eval() # keeping the policy in eval mode as it could be set to train mode while queue is consumed + + if self.config.temporal_ensemble_coeff is not None: + actions = self.predict_action_chunk(batch) + action = self.temporal_ensembler.update(actions) + return action + + # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by + # querying the policy. + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + + # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue + # effectively has shape (n_action_steps, batch_size, *), hence the transpose. + self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() + + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] + + actions = self.model(batch)[0] + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training or validation.""" + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGES] = [batch[key] for key in self.config.image_features] + + batch = self.normalize_targets(batch) + actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) + + l1_loss = ( + F.l1_loss(batch[ACTION], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + ).mean() + + loss_dict = {"l1_loss": l1_loss.item()} + if self.config.use_vae: + # Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for + # each dimension independently, we sum over the latent dimension to get the total + # KL-divergence per batch element, then take the mean over the batch. + # (See App. B of https://huggingface.co/papers/1312.6114 for more details). + mean_kld = ( + (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ) + loss_dict["kld_loss"] = mean_kld.item() + loss = l1_loss + mean_kld * self.config.kl_weight + else: + loss = l1_loss + + return loss, loss_dict + + +class ACTTemporalEnsembler: + def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: + """Temporal ensembling as described in Algorithm 2 of https://huggingface.co/papers/2304.13705. + + The weights are calculated as wᵢ = exp(-temporal_ensemble_coeff * i) where w₀ is the oldest action. + They are then normalized to sum to 1 by dividing by Σwᵢ. Here's some intuition around how the + coefficient works: + - Setting it to 0 uniformly weighs all actions. + - Setting it positive gives more weight to older actions. + - Setting it negative gives more weight to newer actions. + NOTE: The default value for `temporal_ensemble_coeff` used by the original ACT work is 0.01. This + results in older actions being weighed more highly than newer actions (the experiments documented in + https://github.com/huggingface/lerobot/pull/319 hint at why highly weighing new actions might be + detrimental: doing so aggressively may diminish the benefits of action chunking). + + Here we use an online method for computing the average rather than caching a history of actions in + order to compute the average offline. For a simple 1D sequence it looks something like: + + ``` + import torch + + seq = torch.linspace(8, 8.5, 100) + print(seq) + + m = 0.01 + exp_weights = torch.exp(-m * torch.arange(len(seq))) + print(exp_weights) + + # Calculate offline + avg = (exp_weights * seq).sum() / exp_weights.sum() + print("offline", avg) + + # Calculate online + for i, item in enumerate(seq): + if i == 0: + avg = item + continue + avg *= exp_weights[:i].sum() + avg += item * exp_weights[i] + avg /= exp_weights[:i+1].sum() + print("online", avg) + ``` + """ + self.chunk_size = chunk_size + self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) + self.reset() + + def reset(self): + """Resets the online computation variables.""" + self.ensembled_actions = None + # (chunk_size,) count of how many actions are in the ensemble for each time step in the sequence. + self.ensembled_actions_count = None + + def update(self, actions: Tensor) -> Tensor: + """ + Takes a (batch, chunk_size, action_dim) sequence of actions, update the temporal ensemble for all + time steps, and pop/return the next batch of actions in the sequence. + """ + self.ensemble_weights = self.ensemble_weights.to(device=actions.device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + if self.ensembled_actions is None: + # Initializes `self._ensembled_action` to the sequence of actions predicted during the first + # time step of the episode. + self.ensembled_actions = actions.clone() + # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor + # operations later. + self.ensembled_actions_count = torch.ones( + (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + ) + else: + # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute + # the online update for those entries. + self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] + self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] + self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) + # The last action, which has no prior online average, needs to get concatenated onto the end. + self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions_count = torch.cat( + [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] + ) + # "Consume" the first action. + action, self.ensembled_actions, self.ensembled_actions_count = ( + self.ensembled_actions[:, 0], + self.ensembled_actions[:, 1:], + self.ensembled_actions_count[1:], + ) + return action + + +class ACT(nn.Module): + """Action Chunking Transformer: The underlying neural network for ACTPolicy. + + Note: In this code we use the terms `vae_encoder`, 'encoder', `decoder`. The meanings are as follows. + - The `vae_encoder` is, as per the literature around variational auto-encoders (VAE), the part of the + model that encodes the target data (a sequence of actions), and the condition (the robot + joint-space). + - A transformer with an `encoder` (not the VAE encoder) and `decoder` (not the VAE decoder) with + cross-attention is used as the VAE decoder. For these terms, we drop the `vae_` prefix because we + have an option to train this model without the variational objective (in which case we drop the + `vae_encoder` altogether, and nothing about this model has anything to do with a VAE). + + Transformer + Used alone for inference + (acts as VAE decoder + during training) + ┌───────────────────────┐ + │ Outputs │ + │ ▲ │ + │ ┌─────►┌───────┐ │ + ┌──────┐ │ │ │Transf.│ │ + │ │ │ ├─────►│decoder│ │ + ┌────┴────┐ │ │ │ │ │ │ + │ │ │ │ ┌───┴───┬─►│ │ │ + │ VAE │ │ │ │ │ └───────┘ │ + │ encoder │ │ │ │Transf.│ │ + │ │ │ │ │encoder│ │ + └───▲─────┘ │ │ │ │ │ + │ │ │ └▲──▲─▲─┘ │ + │ │ │ │ │ │ │ + inputs └─────┼──┘ │ image emb. │ + │ state emb. │ + └───────────────────────┘ + """ + + def __init__(self, config: ACTConfig): + # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. + # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). + super().__init__() + self.config = config + + if self.config.use_vae: + self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) + self.vae_encoder_cls_embed = nn.Embedding(1, config.dim_model) + # Projection layer for joint-space configuration to hidden dimension. + if self.config.robot_state_feature: + self.vae_encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], config.dim_model + ) + # Projection layer for action (joint-space target) to hidden dimension. + self.vae_encoder_action_input_proj = nn.Linear( + self.config.action_feature.shape[0], + config.dim_model, + ) + # Projection layer from the VAE encoder's output to the latent distribution's parameter space. + self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) + # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch + # dimension. + num_input_token_encoder = 1 + config.chunk_size + if self.config.robot_state_feature: + num_input_token_encoder += 1 + self.register_buffer( + "vae_encoder_pos_enc", + create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), + ) + + # Backbone for image feature extraction. + if self.config.image_features: + backbone_model = getattr(torchvision.models, config.vision_backbone)( + replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + weights=config.pretrained_backbone_weights, + norm_layer=FrozenBatchNorm2d, + ) + # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final + # feature map). + # Note: The forward method of this returns a dict: {"feature_map": output}. + self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + + # Transformer (acts as VAE decoder when training with the variational objective). + self.encoder = ACTEncoder(config) + self.decoder = ACTDecoder(config) + + # Transformer encoder input projections. The tokens will be structured like + # [latent, (robot_state), (env_state), (image_feature_map_pixels)]. + if self.config.robot_state_feature: + self.encoder_robot_state_input_proj = nn.Linear( + self.config.robot_state_feature.shape[0], config.dim_model + ) + if self.config.env_state_feature: + self.encoder_env_state_input_proj = nn.Linear( + self.config.env_state_feature.shape[0], config.dim_model + ) + self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) + if self.config.image_features: + self.encoder_img_feat_input_proj = nn.Conv2d( + backbone_model.fc.in_features, config.dim_model, kernel_size=1 + ) + # Transformer encoder positional embeddings. + n_1d_tokens = 1 # for the latent + if self.config.robot_state_feature: + n_1d_tokens += 1 + if self.config.env_state_feature: + n_1d_tokens += 1 + self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) + if self.config.image_features: + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) + + # Transformer decoder. + # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). + self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) + + # Final action regression head on the output of the transformer's decoder. + self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0]) + + self._reset_parameters() + + def _reset_parameters(self): + """Xavier-uniform initialization of the transformer parameters as in the original code.""" + for p in chain(self.encoder.parameters(), self.decoder.parameters()): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + """A forward pass through the Action Chunking Transformer (with optional VAE encoder). + + `batch` should have the following structure: + { + [robot_state_feature] (optional): (B, state_dim) batch of robot states. + + [image_features]: (B, n_cameras, C, H, W) batch of images. + AND/OR + [env_state_feature]: (B, env_dim) batch of environment states. + + [action_feature] (optional, only if training with VAE): (B, chunk_size, action dim) batch of actions. + } + + Returns: + (B, chunk_size, action_dim) batch of action sequences + Tuple containing the latent PDF's parameters (mean, log(σ²)) both as (B, L) tensors where L is the + latent dimension. + """ + if self.config.use_vae and self.training: + assert "action" in batch, ( + "actions must be provided when using the variational objective in training mode." + ) + + if "observation.images" in batch: + batch_size = batch["observation.images"][0].shape[0] + else: + batch_size = batch["observation.environment_state"].shape[0] + + # Prepare the latent for input to the transformer encoder. + if self.config.use_vae and "action" in batch: + # Prepare the input to the VAE encoder: [cls, *joint_space_configuration, *action_sequence]. + cls_embed = einops.repeat( + self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size + ) # (B, 1, D) + if self.config.robot_state_feature: + robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) + action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) + + if self.config.robot_state_feature: + vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) + else: + vae_encoder_input = [cls_embed, action_embed] + vae_encoder_input = torch.cat(vae_encoder_input, axis=1) + + # Prepare fixed positional embedding. + # Note: detach() shouldn't be necessary but leaving it the same as the original code just in case. + pos_embed = self.vae_encoder_pos_enc.clone().detach() # (1, S+2, D) + + # Prepare key padding mask for the transformer encoder. We have 1 or 2 extra tokens at the start of the + # sequence depending whether we use the input states or not (cls and robot state) + # False means not a padding token. + cls_joint_is_pad = torch.full( + (batch_size, 2 if self.config.robot_state_feature else 1), + False, + device=batch["observation.state"].device, + ) + key_padding_mask = torch.cat( + [cls_joint_is_pad, batch["action_is_pad"]], axis=1 + ) # (bs, seq+1 or 2) + + # Forward pass through VAE encoder to get the latent PDF parameters. + cls_token_out = self.vae_encoder( + vae_encoder_input.permute(1, 0, 2), + pos_embed=pos_embed.permute(1, 0, 2), + key_padding_mask=key_padding_mask, + )[0] # select the class token, with shape (B, D) + latent_pdf_params = self.vae_encoder_latent_output_proj(cls_token_out) + mu = latent_pdf_params[:, : self.config.latent_dim] + # This is 2log(sigma). Done this way to match the original implementation. + log_sigma_x2 = latent_pdf_params[:, self.config.latent_dim :] + + # Sample the latent with the reparameterization trick. + latent_sample = mu + log_sigma_x2.div(2).exp() * torch.randn_like(mu) + else: + # When not using the VAE encoder, we set the latent to be all zeros. + mu = log_sigma_x2 = None + # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer + latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( + batch["observation.state"].device + ) + + # Prepare transformer encoder inputs. + encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] + encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) + # Robot state token. + if self.config.robot_state_feature: + encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) + # Environment state token. + if self.config.env_state_feature: + encoder_in_tokens.append( + self.encoder_env_state_input_proj(batch["observation.environment_state"]) + ) + + # Camera observation features and positional embeddings. + if self.config.image_features: + all_cam_features = [] + all_cam_pos_embeds = [] + + # For a list of images, the H and W may vary but H*W is constant. + for img in batch["observation.images"]: + cam_features = self.backbone(img)["feature_map"] + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) + cam_features = self.encoder_img_feat_input_proj(cam_features) + + # Rearrange features to (sequence, batch, dim). + cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c") + cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c") + + all_cam_features.append(cam_features) + all_cam_pos_embeds.append(cam_pos_embed) + + encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0)) + encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0)) + + # Stack all tokens along the sequence dimension. + encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0) + encoder_in_pos_embed = torch.stack(encoder_in_pos_embed, axis=0) + + # Forward pass through the transformer modules. + encoder_out = self.encoder(encoder_in_tokens, pos_embed=encoder_in_pos_embed) + # TODO(rcadene, alexander-soare): remove call to `device` ; precompute and use buffer + decoder_in = torch.zeros( + (self.config.chunk_size, batch_size, self.config.dim_model), + dtype=encoder_in_pos_embed.dtype, + device=encoder_in_pos_embed.device, + ) + decoder_out = self.decoder( + decoder_in, + encoder_out, + encoder_pos_embed=encoder_in_pos_embed, + decoder_pos_embed=self.decoder_pos_embed.weight.unsqueeze(1), + ) + + # Move back to (B, S, C). + decoder_out = decoder_out.transpose(0, 1) + + actions = self.action_head(decoder_out) + + return actions, (mu, log_sigma_x2) + + +class ACTEncoder(nn.Module): + """Convenience module for running multiple encoder layers, maybe followed by normalization.""" + + def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): + super().__init__() + self.is_vae_encoder = is_vae_encoder + num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers + self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) + self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() + + def forward( + self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + ) -> Tensor: + for layer in self.layers: + x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) + x = self.norm(x) + return x + + +class ACTEncoderLayer(nn.Module): + def __init__(self, config: ACTConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + + # Feed forward layers. + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) + + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm + + def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = x if pos_embed is None else x + pos_embed + x = self.self_attn(q, k, value=x, key_padding_mask=key_padding_mask) + x = x[0] # note: [0] to select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout2(x) + if not self.pre_norm: + x = self.norm2(x) + return x + + +class ACTDecoder(nn.Module): + def __init__(self, config: ACTConfig): + """Convenience module for running multiple decoder layers followed by normalization.""" + super().__init__() + self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) + self.norm = nn.LayerNorm(config.dim_model) + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + for layer in self.layers: + x = layer( + x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + ) + if self.norm is not None: + x = self.norm(x) + return x + + +class ACTDecoderLayer(nn.Module): + def __init__(self, config: ACTConfig): + super().__init__() + self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + + # Feed forward layers. + self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) + self.dropout = nn.Dropout(config.dropout) + self.linear2 = nn.Linear(config.dim_feedforward, config.dim_model) + + self.norm1 = nn.LayerNorm(config.dim_model) + self.norm2 = nn.LayerNorm(config.dim_model) + self.norm3 = nn.LayerNorm(config.dim_model) + self.dropout1 = nn.Dropout(config.dropout) + self.dropout2 = nn.Dropout(config.dropout) + self.dropout3 = nn.Dropout(config.dropout) + + self.activation = get_activation_fn(config.feedforward_activation) + self.pre_norm = config.pre_norm + + def maybe_add_pos_embed(self, tensor: Tensor, pos_embed: Tensor | None) -> Tensor: + return tensor if pos_embed is None else tensor + pos_embed + + def forward( + self, + x: Tensor, + encoder_out: Tensor, + decoder_pos_embed: Tensor | None = None, + encoder_pos_embed: Tensor | None = None, + ) -> Tensor: + """ + Args: + x: (Decoder Sequence, Batch, Channel) tensor of input tokens. + encoder_out: (Encoder Sequence, B, C) output features from the last layer of the encoder we are + cross-attending with. + decoder_pos_embed: (ES, 1, C) positional embedding for keys (from the encoder). + encoder_pos_embed: (DS, 1, C) Positional_embedding for the queries (from the decoder). + Returns: + (DS, B, C) tensor of decoder output features. + """ + skip = x + if self.pre_norm: + x = self.norm1(x) + q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) + x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = skip + self.dropout1(x) + if self.pre_norm: + skip = x + x = self.norm2(x) + else: + x = self.norm1(x) + skip = x + x = self.multihead_attn( + query=self.maybe_add_pos_embed(x, decoder_pos_embed), + key=self.maybe_add_pos_embed(encoder_out, encoder_pos_embed), + value=encoder_out, + )[0] # select just the output, not the attention weights + x = skip + self.dropout2(x) + if self.pre_norm: + skip = x + x = self.norm3(x) + else: + x = self.norm2(x) + skip = x + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + x = skip + self.dropout3(x) + if not self.pre_norm: + x = self.norm3(x) + return x + + +def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tensor: + """1D sinusoidal positional embeddings as in Attention is All You Need. + + Args: + num_positions: Number of token positions required. + Returns: (num_positions, dimension) position embeddings (the first dimension is the batch dimension). + + """ + + def get_position_angle_vec(position): + return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] + + sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + return torch.from_numpy(sinusoid_table).float() + + +class ACTSinusoidalPositionEmbedding2d(nn.Module): + """2D sinusoidal positional embeddings similar to what's presented in Attention Is All You Need. + + The variation is that the position indices are normalized in [0, 2π] (not quite: the lower bound is 1/H + for the vertical direction, and 1/W for the horizontal direction. + """ + + def __init__(self, dimension: int): + """ + Args: + dimension: The desired dimension of the embeddings. + """ + super().__init__() + self.dimension = dimension + self._two_pi = 2 * math.pi + self._eps = 1e-6 + # Inverse "common ratio" for the geometric progression in sinusoid frequencies. + self._temperature = 10000 + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x: A (B, C, H, W) batch of 2D feature map to generate the embeddings for. + Returns: + A (1, C, H, W) batch of corresponding sinusoidal positional embeddings. + """ + not_mask = torch.ones_like(x[0, :1]) # (1, H, W) + # Note: These are like range(1, H+1) and range(1, W+1) respectively, but in most implementations + # they would be range(0, H) and range(0, W). Keeping it at as is to match the original code. + y_range = not_mask.cumsum(1, dtype=torch.float32) + x_range = not_mask.cumsum(2, dtype=torch.float32) + + # "Normalize" the position index such that it ranges in [0, 2π]. + # Note: Adding epsilon on the denominator should not be needed as all values of y_embed and x_range + # are non-zero by construction. This is an artifact of the original code. + y_range = y_range / (y_range[:, -1:, :] + self._eps) * self._two_pi + x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi + + inverse_frequency = self._temperature ** ( + 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + ) + + x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + y_range = y_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) + + # Note: this stack then flatten operation results in interleaved sine and cosine terms. + # pos_embed_x and pos_embed_y are (1, H, W, C // 2). + pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) + + return pos_embed + + +def get_activation_fn(activation: str) -> Callable: + """Return an activation function given a string.""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.") diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index e69de29bb..ef56bdb61 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -0,0 +1,178 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from torch import nn + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType +from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata +from lerobot.datasets.utils import dataset_to_policy_features +from lerobot.envs.configs import EnvConfig +from lerobot.envs.utils import env_to_policy_features +from lerobot.policies.act.configuration_act import ACTConfig +from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig +from lerobot.policies.pi0.configuration_pi0 import PI0Config +from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.sac.configuration_sac import SACConfig +from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig +from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig +from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig + + +def get_policy_class(name: str) -> PreTrainedPolicy: + """Get the policy's class and config class given a name (matching the policy class' `name` attribute).""" + if name == "tdmpc": + from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy + + return TDMPCPolicy + elif name == "diffusion": + from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy + + return DiffusionPolicy + elif name == "act": + from lerobot.policies.act.modeling_act import ACTPolicy + + return ACTPolicy + elif name == "vqbet": + from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy + + return VQBeTPolicy + elif name == "pi0": + from lerobot.policies.pi0.modeling_pi0 import PI0Policy + + return PI0Policy + elif name == "pi0fast": + from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy + + return PI0FASTPolicy + elif name == "sac": + from lerobot.policies.sac.modeling_sac import SACPolicy + + return SACPolicy + elif name == "reward_classifier": + from lerobot.policies.sac.reward_model.modeling_classifier import Classifier + + return Classifier + elif name == "smolvla": + from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy + + return SmolVLAPolicy + else: + raise NotImplementedError(f"Policy with name {name} is not implemented.") + + +def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: + if policy_type == "tdmpc": + return TDMPCConfig(**kwargs) + elif policy_type == "diffusion": + return DiffusionConfig(**kwargs) + elif policy_type == "act": + return ACTConfig(**kwargs) + elif policy_type == "vqbet": + return VQBeTConfig(**kwargs) + elif policy_type == "pi0": + return PI0Config(**kwargs) + elif policy_type == "pi0fast": + return PI0FASTConfig(**kwargs) + elif policy_type == "sac": + return SACConfig(**kwargs) + elif policy_type == "smolvla": + return SmolVLAConfig(**kwargs) + elif policy_type == "reward_classifier": + return RewardClassifierConfig(**kwargs) + else: + raise ValueError(f"Policy type '{policy_type}' is not available.") + + +def make_policy( + cfg: PreTrainedConfig, + ds_meta: LeRobotDatasetMetadata | None = None, + env_cfg: EnvConfig | None = None, +) -> PreTrainedPolicy: + """Make an instance of a policy class. + + This function exists because (for now) we need to parse features from either a dataset or an environment + in order to properly dimension and instantiate a policy for that dataset or environment. + + Args: + cfg (PreTrainedConfig): The config of the policy to make. If `pretrained_path` is set, the policy will + be loaded with the weights from that path. + ds_meta (LeRobotDatasetMetadata | None, optional): Dataset metadata to take input/output shapes and + statistics to use for (un)normalization of inputs/outputs in the policy. Defaults to None. + env_cfg (EnvConfig | None, optional): The config of a gym environment to parse features from. Must be + provided if ds_meta is not. Defaults to None. + + Raises: + ValueError: Either ds_meta or env and env_cfg must be provided. + NotImplementedError: if the policy.type is 'vqbet' and the policy device 'mps' (due to an incompatibility) + + Returns: + PreTrainedPolicy: _description_ + """ + if bool(ds_meta) == bool(env_cfg): + raise ValueError("Either one of a dataset metadata or a sim env must be provided.") + + # NOTE: Currently, if you try to run vqbet with mps backend, you'll get this error. + # TODO(aliberts, rcadene): Implement a check_backend_compatibility in policies? + # NotImplementedError: The operator 'aten::unique_dim' is not currently implemented for the MPS device. If + # you want this op to be added in priority during the prototype phase of this feature, please comment on + # https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment + # variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be + # slower than running natively on MPS. + if cfg.type == "vqbet" and cfg.device == "mps": + raise NotImplementedError( + "Current implementation of VQBeT does not support `mps` backend. " + "Please use `cpu` or `cuda` backend." + ) + + policy_cls = get_policy_class(cfg.type) + + kwargs = {} + if ds_meta is not None: + features = dataset_to_policy_features(ds_meta.features) + kwargs["dataset_stats"] = ds_meta.stats + else: + if not cfg.pretrained_path: + logging.warning( + "You are instantiating a policy from scratch and its features are parsed from an environment " + "rather than a dataset. Normalization modules inside the policy will have infinite values " + "by default without stats from a dataset." + ) + features = env_to_policy_features(env_cfg) + + cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION} + cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} + kwargs["config"] = cfg + + if cfg.pretrained_path: + # Load a pretrained policy and override the config if needed (for example, if there are inference-time + # hyperparameters that we want to vary). + kwargs["pretrained_name_or_path"] = cfg.pretrained_path + policy = policy_cls.from_pretrained(**kwargs) + else: + # Make a fresh policy. + policy = policy_cls(**kwargs) + + policy.to(cfg.device) + assert isinstance(policy, nn.Module) + + # policy = torch.compile(policy, mode="reduce-overhead") + + return policy diff --git a/src/lerobot/policies/tdmpc/modeling_tdmpc.py b/src/lerobot/policies/tdmpc/modeling_tdmpc.py index e69de29bb..8b70b265d 100644 --- a/src/lerobot/policies/tdmpc/modeling_tdmpc.py +++ b/src/lerobot/policies/tdmpc/modeling_tdmpc.py @@ -0,0 +1,834 @@ +#!/usr/bin/env python + +# Copyright 2024 Nicklas Hansen, Xiaolong Wang, Hao Su, +# and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Implementation of Finetuning Offline World Models in the Real World. + +The comments in this code may sometimes refer to these references: + TD-MPC paper: Temporal Difference Learning for Model Predictive Control (https://huggingface.co/papers/2203.04955) + FOWM paper: Finetuning Offline World Models in the Real World (https://huggingface.co/papers/2310.16029) +""" + +# ruff: noqa: N806 + +from collections import deque +from copy import deepcopy +from functools import partial +from typing import Callable + +import einops +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor + +from lerobot.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_STATE, REWARD +from lerobot.policies.normalize import Normalize, Unnormalize +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig +from lerobot.policies.utils import get_device_from_parameters, get_output_shape, populate_queues + + +class TDMPCPolicy(PreTrainedPolicy): + """Implementation of TD-MPC learning + inference. + + Please note several warnings for this policy. + - Evaluation of pretrained weights created with the original FOWM code + (https://github.com/fyhMer/fowm) works as expected. To be precise: we trained and evaluated a + model with the FOWM code for the xarm_lift_medium_replay dataset. We ported the weights across + to LeRobot, and were able to evaluate with the same success metric. BUT, we had to use inter- + process communication to use the xarm environment from FOWM. This is because our xarm + environment uses newer dependencies and does not match the environment in FOWM. See + https://github.com/huggingface/lerobot/pull/103 for implementation details. + - We have NOT checked that training on LeRobot reproduces the results from FOWM. + - Nevertheless, we have verified that we can train TD-MPC for PushT. See + `lerobot/configs/policy/tdmpc_pusht_keypoints.yaml`. + - Our current xarm datasets were generated using the environment from FOWM. Therefore they do not + match our xarm environment. + """ + + config_class = TDMPCConfig + name = "tdmpc" + + def __init__(self, config: TDMPCConfig, dataset_stats: dict[str, dict[str, Tensor]] | None = None): + """ + Args: + config: Policy configuration class instance or None, in which case the default instantiation of + the configuration class is used. + dataset_stats: Dataset statistics to be used for normalization. If not passed here, it is expected + that they will be passed with a call to `load_state_dict` before the policy is used. + """ + super().__init__(config) + config.validate_features() + self.config = config + + self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats) + self.normalize_targets = Normalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + self.unnormalize_outputs = Unnormalize( + config.output_features, config.normalization_mapping, dataset_stats + ) + + self.model = TDMPCTOLD(config) + self.model_target = deepcopy(self.model) + for param in self.model_target.parameters(): + param.requires_grad = False + + self.reset() + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """ + Clear observation and action queues. Clear previous means for warm starting of MPPI/CEM. Should be + called on `env.reset()` + """ + self._queues = { + "observation.state": deque(maxlen=1), + "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), + } + if self.config.image_features: + self._queues["observation.image"] = deque(maxlen=1) + if self.config.env_state_feature: + self._queues["observation.environment_state"] = deque(maxlen=1) + # Previous mean obtained from the cross-entropy method (CEM) used during MPC. It is used to warm start + # CEM for the next step. + self._prev_mean: torch.Tensor | None = None + + @torch.no_grad + def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + batch = {key: torch.stack(list(self._queues[key]), dim=1) for key in batch if key in self._queues} + + # Remove the time dimensions as it is not handled yet. + for key in batch: + assert batch[key].shape[1] == 1 + batch[key] = batch[key][:, 0] + + # NOTE: Order of observations matters here. + encode_keys = [] + if self.config.image_features: + encode_keys.append(OBS_IMAGE) + if self.config.env_state_feature: + encode_keys.append(OBS_ENV_STATE) + encode_keys.append(OBS_STATE) + z = self.model.encode({k: batch[k] for k in encode_keys}) + if self.config.use_mpc: # noqa: SIM108 + actions = self.plan(z) # (horizon, batch, action_dim) + else: + # Plan with the policy (π) alone. This always returns one action so unsqueeze to get a + # sequence dimension like in the MPC branch. + actions = self.model.pi(z).unsqueeze(0) + + actions = torch.clamp(actions, -1, +1) + + actions = self.unnormalize_outputs({ACTION: actions})[ACTION] + return actions + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations.""" + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] + + self._queues = populate_queues(self._queues, batch) + + # When the action queue is depleted, populate it again by querying the policy. + if len(self._queues[ACTION]) == 0: + actions = self.predict_action_chunk(batch) + + if self.config.n_action_repeats > 1: + for _ in range(self.config.n_action_repeats): + self._queues[ACTION].append(actions[0]) + else: + # Action queue is (n_action_steps, batch_size, action_dim), so we transpose the action. + self._queues[ACTION].extend(actions[: self.config.n_action_steps]) + + action = self._queues[ACTION].popleft() + return action + + @torch.no_grad() + def plan(self, z: Tensor) -> Tensor: + """Plan sequence of actions using TD-MPC inference. + + Args: + z: (batch, latent_dim,) tensor for the initial state. + Returns: + (horizon, batch, action_dim,) tensor for the planned trajectory of actions. + """ + device = get_device_from_parameters(self) + + batch_size = z.shape[0] + + # Sample Nπ trajectories from the policy. + pi_actions = torch.empty( + self.config.horizon, + self.config.n_pi_samples, + batch_size, + self.config.action_feature.shape[0], + device=device, + ) + if self.config.n_pi_samples > 0: + _z = einops.repeat(z, "b d -> n b d", n=self.config.n_pi_samples) + for t in range(self.config.horizon): + # Note: Adding a small amount of noise here doesn't hurt during inference and may even be + # helpful for CEM. + pi_actions[t] = self.model.pi(_z, self.config.min_std) + _z = self.model.latent_dynamics(_z, pi_actions[t]) + + # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled + # trajectories. + z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) + + # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization + # algorithm. + # The initial mean and standard deviation for the cross-entropy method (CEM). + mean = torch.zeros( + self.config.horizon, batch_size, self.config.action_feature.shape[0], device=device + ) + # Maybe warm start CEM with the mean from the previous step. + if self._prev_mean is not None: + mean[:-1] = self._prev_mean[1:] + std = self.config.max_std * torch.ones_like(mean) + + for _ in range(self.config.cem_iterations): + # Randomly sample action trajectories for the gaussian distribution. + std_normal_noise = torch.randn( + self.config.horizon, + self.config.n_gaussian_samples, + batch_size, + self.config.action_feature.shape[0], + device=std.device, + ) + gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) + + # Compute elite actions. + actions = torch.cat([gaussian_actions, pi_actions], dim=1) + value = self.estimate_value(z, actions).nan_to_num_(0) + elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch) + elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch) + # (horizon, n_elites, batch, action_dim) + elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1) + + # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites. + max_value = elite_value.max(0, keepdim=True)[0] # (1, batch) + # The weighting is a softmax over trajectory values. Note that this is not the same as the usage + # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This + # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). + score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) + score /= score.sum(axis=0, keepdim=True) + # (horizon, batch, action_dim) + _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) + _std = torch.sqrt( + torch.sum( + einops.rearrange(score, "n b -> n b 1") + * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2, + dim=1, + ) + ) + # Update mean with an exponential moving average, and std with a direct replacement. + mean = ( + self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean + ) + std = _std.clamp_(self.config.min_std, self.config.max_std) + + # Keep track of the mean for warm-starting subsequent steps. + self._prev_mean = mean + + # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax + # scores from the last iteration. + actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)] + + return actions + + @torch.no_grad() + def estimate_value(self, z: Tensor, actions: Tensor): + """Estimates the value of a trajectory as per eqn 4 of the FOWM paper. + + Args: + z: (batch, latent_dim) tensor of initial latent states. + actions: (horizon, batch, action_dim) tensor of action trajectories. + Returns: + (batch,) tensor of values. + """ + # Initialize return and running discount factor. + G, running_discount = 0, 1 + # Iterate over the actions in the trajectory to simulate the trajectory using the latent dynamics + # model. Keep track of return. + for t in range(actions.shape[0]): + # We will compute the reward in a moment. First compute the uncertainty regularizer from eqn 4 + # of the FOWM paper. + if self.config.uncertainty_regularizer_coeff > 0: + regularization = -( + self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0) + ) + else: + regularization = 0 + # Estimate the next state (latent) and reward. + z, reward = self.model.latent_dynamics_and_reward(z, actions[t]) + # Update the return and running discount. + G += running_discount * (reward + regularization) + running_discount *= self.config.discount + # Add the estimated value of the final state (using the minimum for a conservative estimate). + # Do so by predicting the next action, then taking a minimum over the ensemble of state-action value + # estimators. + # Note: This small amount of added noise seems to help a bit at inference time as observed by success + # metrics over 50 episodes of xarm_lift_medium_replay. + next_action = self.model.pi(z, self.config.min_std) # (batch, action_dim) + terminal_values = self.model.Qs(z, next_action) # (ensemble, batch) + # Randomly choose 2 of the Qs for terminal value estimation (as in App C. of the FOWM paper). + if self.config.q_ensemble_size > 2: + G += ( + running_discount + * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[ + 0 + ] + ) + else: + G += running_discount * torch.min(terminal_values, dim=0)[0] + # Finally, also regularize the terminal value. + if self.config.uncertainty_regularizer_coeff > 0: + G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) + return G + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss. + + Returns a dictionary with loss as a tensor, and other information as native floats. + """ + device = get_device_from_parameters(self) + + batch = self.normalize_inputs(batch) + if self.config.image_features: + batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch[OBS_IMAGE] = batch[next(iter(self.config.image_features))] + batch = self.normalize_targets(batch) + + info = {} + + # (b, t) -> (t, b) + for key in batch: + if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + action = batch[ACTION] # (t, b, action_dim) + reward = batch[REWARD] # (t, b) + observations = {k: v for k, v in batch.items() if k.startswith("observation.")} + + # Apply random image augmentations. + if self.config.image_features and self.config.max_random_shift_ratio > 0: + observations[OBS_IMAGE] = flatten_forward_unflatten( + partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), + observations[OBS_IMAGE], + ) + + # Get the current observation for predicting trajectories, and all future observations for use in + # the latent consistency loss and TD loss. + current_observation, next_observations = {}, {} + for k in observations: + current_observation[k] = observations[k][0] + next_observations[k] = observations[k][1:] + horizon, batch_size = next_observations[ + OBS_IMAGE if self.config.image_features else OBS_ENV_STATE + ].shape[:2] + + # Run latent rollout using the latent dynamics model and policy model. + # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action + # gives us a next `z`. + batch_size = batch["index"].shape[0] + z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) + z_preds[0] = self.model.encode(current_observation) + reward_preds = torch.empty_like(reward, device=device) + for t in range(horizon): + z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t]) + + # Compute Q and V value predictions based on the latent rollout. + q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch) + v_preds = self.model.V(z_preds[:-1]) + info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()}) + + # Compute various targets with stopgrad. + with torch.no_grad(): + # Latent state consistency targets. + z_targets = self.model_target.encode(next_observations) + # State-action value targets (or TD targets) as in eqn 3 of the FOWM. Unlike TD-MPC which uses the + # learned state-action value function in conjunction with the learned policy: Q(z, π(z)), FOWM + # uses a learned state value function: V(z). This means the TD targets only depend on in-sample + # actions (not actions estimated by π). + # Note: Here we do not use self.model_target, but self.model. This is to follow the original code + # and the FOWM paper. + q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations)) + # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we + # are using them to compute loss for V. + v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True) + + # Compute losses. + # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the + # future have less impact on the loss. Note: unsqueeze will let us broadcast to (seq, batch). + temporal_loss_coeffs = torch.pow( + self.config.temporal_decay_coeff, torch.arange(horizon, device=device) + ).unsqueeze(-1) + # Compute consistency loss as MSE loss between latents predicted from the rollout and latents + # predicted from the (target model's) observation encoder. + consistency_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(z_preds[1:], z_targets, reduction="none").mean(dim=-1) + # `z_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # `z_targets` depends on the next observation. + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute the reward loss as MSE loss between rewards predicted from the rollout and the dataset + # rewards. + reward_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss(reward_preds, reward, reduction="none") + * ~batch["next.reward_is_pad"] + # `reward_preds` depends on the current observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble. + q_value_loss = ( + ( + temporal_loss_coeffs + * F.mse_loss( + q_preds_ensemble, + einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), + reduction="none", + ).sum(0) # sum over ensemble + # `q_preds_ensemble` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + # q_targets depends on the reward and the next observations. + * ~batch["next.reward_is_pad"] + * ~batch["observation.state_is_pad"][1:] + ) + .sum(0) + .mean() + ) + # Compute state value loss as in eqn 3 of FOWM. + diff = v_targets - v_preds + # Expectile loss penalizes: + # - `v_preds < v_targets` with weighting `expectile_weight` + # - `v_preds >= v_targets` with weighting `1 - expectile_weight` + raw_v_value_loss = torch.where( + diff > 0, self.config.expectile_weight, (1 - self.config.expectile_weight) + ) * (diff**2) + v_value_loss = ( + ( + temporal_loss_coeffs + * raw_v_value_loss + # `v_targets` depends on the first observation and the actions, as does `v_preds`. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ) + .sum(0) + .mean() + ) + + # Calculate the advantage weighted regression loss for π as detailed in FOWM 3.1. + # We won't need these gradients again so detach. + z_preds = z_preds.detach() + # Use stopgrad for the advantage calculation. + with torch.no_grad(): + advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V( + z_preds[:-1] + ) + info["advantage"] = advantage[0] + # (t, b) + exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0) + action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) + # Calculate the MSE between the actions and the action predictions. + # Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation + # gaussian) and sums over the action dimension. Computing the (negative) log probability amounts to + # multiplying the MSE by 0.5 and adding a constant offset (the log(2*pi)/2 term, times the action + # dimension). Here we drop the constant offset as it doesn't change the optimization step, and we drop + # the 0.5 as we instead make a configuration parameter for it (see below where we compute the total + # loss). + mse = F.mse_loss(action_preds, action, reduction="none").sum(-1) # (t, b) + # NOTE: The original implementation does not take the sum over the temporal dimension like with the + # other losses. + # TODO(alexander-soare): Take the sum over the temporal dimension and check that training still works + # as well as expected. + pi_loss = ( + exp_advantage + * mse + * temporal_loss_coeffs + # `action_preds` depends on the first observation and the actions. + * ~batch["observation.state_is_pad"][0] + * ~batch["action_is_pad"] + ).mean() + + loss = ( + self.config.consistency_coeff * consistency_loss + + self.config.reward_coeff * reward_loss + + self.config.value_coeff * q_value_loss + + self.config.value_coeff * v_value_loss + + self.config.pi_coeff * pi_loss + ) + + info.update( + { + "consistency_loss": consistency_loss.item(), + "reward_loss": reward_loss.item(), + "Q_value_loss": q_value_loss.item(), + "V_value_loss": v_value_loss.item(), + "pi_loss": pi_loss.item(), + "sum_loss": loss.item() * self.config.horizon, + } + ) + + # Undo (b, t) -> (t, b). + for key in batch: + if isinstance(batch[key], torch.Tensor) and batch[key].ndim > 1: + batch[key] = batch[key].transpose(1, 0) + + return loss, info + + def update(self): + """Update the target model's parameters with an EMA step.""" + # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA + # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code + # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995) + update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) + + +class TDMPCTOLD(nn.Module): + """Task-Oriented Latent Dynamics (TOLD) model used in TD-MPC.""" + + def __init__(self, config: TDMPCConfig): + super().__init__() + self.config = config + self._encoder = TDMPCObservationEncoder(config) + self._dynamics = nn.Sequential( + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + self._reward = nn.Sequential( + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, 1), + ) + self._pi = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Mish(), + nn.Linear(config.mlp_dim, config.action_feature.shape[0]), + ) + self._Qs = nn.ModuleList( + [ + nn.Sequential( + nn.Linear(config.latent_dim + config.action_feature.shape[0], config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + for _ in range(config.q_ensemble_size) + ] + ) + self._V = nn.Sequential( + nn.Linear(config.latent_dim, config.mlp_dim), + nn.LayerNorm(config.mlp_dim), + nn.Tanh(), + nn.Linear(config.mlp_dim, config.mlp_dim), + nn.ELU(), + nn.Linear(config.mlp_dim, 1), + ) + self._init_weights() + + def _init_weights(self): + """Initialize model weights. + + Orthogonal initialization for all linear and convolutional layers' weights (apart from final layers + of reward network and Q networks which get zero initialization). + Zero initialization for all linear and convolutional layers' biases. + """ + + def _apply_fn(m): + if isinstance(m, nn.Linear): + nn.init.orthogonal_(m.weight.data) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Conv2d): + gain = nn.init.calculate_gain("relu") + nn.init.orthogonal_(m.weight.data, gain) + if m.bias is not None: + nn.init.zeros_(m.bias) + + self.apply(_apply_fn) + for m in [self._reward, *self._Qs]: + assert isinstance(m[-1], nn.Linear), ( + "Sanity check. The last linear layer needs 0 initialization on weights." + ) + nn.init.zeros_(m[-1].weight) + nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure + + def encode(self, obs: dict[str, Tensor]) -> Tensor: + """Encodes an observation into its latent representation.""" + return self._encoder(obs) + + def latent_dynamics_and_reward(self, z: Tensor, a: Tensor) -> tuple[Tensor, Tensor]: + """Predict the next state's latent representation and the reward given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + A tuple containing: + - (*, latent_dim) tensor for the next state's latent representation. + - (*,) tensor for the estimated reward. + """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x), self._reward(x).squeeze(-1) + + def latent_dynamics(self, z: Tensor, a: Tensor) -> Tensor: + """Predict the next state's latent representation given a current latent and action. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + Returns: + (*, latent_dim) tensor for the next state's latent representation. + """ + x = torch.cat([z, a], dim=-1) + return self._dynamics(x) + + def pi(self, z: Tensor, std: float = 0.0) -> Tensor: + """Samples an action from the learned policy. + + The policy can also have added (truncated) Gaussian noise injected for encouraging exploration when + generating rollouts for online training. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + std: The standard deviation of the injected noise. + Returns: + (*, action_dim) tensor for the sampled action. + """ + action = torch.tanh(self._pi(z)) + if std > 0: + std = torch.ones_like(action) * std + action += torch.randn_like(action) * std + return action + + def V(self, z: Tensor) -> Tensor: # noqa: N802 + """Predict state value (V). + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + Returns: + (*,) tensor of estimated state values. + """ + return self._V(z).squeeze(-1) + + def Qs(self, z: Tensor, a: Tensor, return_min: bool = False) -> Tensor: # noqa: N802 + """Predict state-action value for all of the learned Q functions. + + Args: + z: (*, latent_dim) tensor for the current state's latent representation. + a: (*, action_dim) tensor for the action to be applied. + return_min: Set to true for implementing the detail in App. C of the FOWM paper: randomly select + 2 of the Qs and return the minimum + Returns: + (q_ensemble, *) tensor for the value predictions of each learned Q function in the ensemble OR + (*,) tensor if return_min=True. + """ + x = torch.cat([z, a], dim=-1) + if not return_min: + return torch.stack([q(x).squeeze(-1) for q in self._Qs], dim=0) + else: + if len(self._Qs) > 2: # noqa: SIM108 + Qs = [self._Qs[i] for i in np.random.choice(len(self._Qs), size=2)] + else: + Qs = self._Qs + return torch.stack([q(x).squeeze(-1) for q in Qs], dim=0).min(dim=0)[0] + + +class TDMPCObservationEncoder(nn.Module): + """Encode image and/or state vector observations.""" + + def __init__(self, config: TDMPCConfig): + """ + Creates encoders for pixel and/or state modalities. + TODO(alexander-soare): The original work allows for multiple images by concatenating them along the + channel dimension. Re-implement this capability. + """ + super().__init__() + self.config = config + + if config.image_features: + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + next(iter(config.image_features.values())).shape[0], + config.image_encoder_hidden_dim, + 7, + stride=2, + ), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.ReLU(), + ) + dummy_shape = (1, *next(iter(config.image_features.values())).shape) + out_shape = get_output_shape(self.image_enc_layers, dummy_shape)[1:] + self.image_enc_layers.extend( + nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + ) + + if config.robot_state_feature: + self.state_enc_layers = nn.Sequential( + nn.Linear(config.robot_state_feature.shape[0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + + if config.env_state_feature: + self.env_state_enc_layers = nn.Sequential( + nn.Linear(config.env_state_feature.shape[0], config.state_encoder_hidden_dim), + nn.ELU(), + nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Sigmoid(), + ) + + def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: + """Encode the image and/or state vector. + + Each modality is encoded into a feature vector of size (latent_dim,) and then a uniform mean is taken + over all features. + """ + feat = [] + # NOTE: Order of observations matters here. + if self.config.image_features: + feat.append( + flatten_forward_unflatten( + self.image_enc_layers, obs_dict[next(iter(self.config.image_features))] + ) + ) + if self.config.env_state_feature: + feat.append(self.env_state_enc_layers(obs_dict[OBS_ENV_STATE])) + if self.config.robot_state_feature: + feat.append(self.state_enc_layers(obs_dict[OBS_STATE])) + return torch.stack(feat, dim=0).mean(0) + + +def random_shifts_aug(x: Tensor, max_random_shift_ratio: float) -> Tensor: + """Randomly shifts images horizontally and vertically. + + Adapted from https://github.com/facebookresearch/drqv2 + """ + b, _, h, w = x.size() + assert h == w, "non-square images not handled yet" + pad = int(round(max_random_shift_ratio * h)) + x = F.pad(x, tuple([pad] * 4), "replicate") + eps = 1.0 / (h + 2 * pad) + arange = torch.linspace( + -1.0 + eps, + 1.0 - eps, + h + 2 * pad, + device=x.device, + dtype=torch.float32, + )[:h] + arange = einops.repeat(arange, "w -> h w 1", h=h) + base_grid = torch.cat([arange, arange.transpose(1, 0)], dim=2) + base_grid = einops.repeat(base_grid, "h w c -> b h w c", b=b) + # A random shift in units of pixels and within the boundaries of the padding. + shift = torch.randint( + 0, + 2 * pad + 1, + size=(b, 1, 1, 2), + device=x.device, + dtype=torch.float32, + ) + shift *= 2.0 / (h + 2 * pad) + grid = base_grid + shift + return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) + + +def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): + """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param.""" + for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True): + for (n_p_ema, p_ema), (n_p, p) in zip( + ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True + ): + assert n_p_ema == n_p, "Parameter names don't match for EMA model update" + if isinstance(p, dict): + raise RuntimeError("Dict parameter not supported") + if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad: + # Copy BatchNorm parameters, and non-trainable parameters directly. + p_ema.copy_(p.to(dtype=p_ema.dtype).data) + with torch.no_grad(): + p_ema.mul_(alpha) + p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha) + + +def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: + """Helper to temporarily flatten extra dims at the start of the image tensor. + + Args: + fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return + (B, *), where * is any number of dimensions. + image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions, generally + different from *. + Returns: + A return value from the callable reshaped to (**, *). + """ + if image_tensor.ndim == 4: + return fn(image_tensor) + start_dims = image_tensor.shape[:-3] + inp = torch.flatten(image_tensor, end_dim=-4) + flat_out = fn(inp) + return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:])) diff --git a/src/lerobot/robots/viperx/README.md b/src/lerobot/robots/viperx/README.md index e69de29bb..445368e7a 100644 --- a/src/lerobot/robots/viperx/README.md +++ b/src/lerobot/robots/viperx/README.md @@ -0,0 +1,182 @@ +This tutorial explains how to use [Aloha and Aloha 2 stationary](https://www.trossenrobotics.com/aloha-stationary) with LeRobot. + +## Setup + +Follow the [documentation from Trossen Robotics](https://docs.trossenrobotics.com/aloha_docs/2.0/getting_started/stationary/hardware_setup.html) for setting up the hardware and plugging the 4 arms and 4 cameras to your computer. + + +## Install LeRobot + +On your computer: + +1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install): +```bash +mkdir -p ~/miniconda3 +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh +bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 +rm ~/miniconda3/miniconda.sh +~/miniconda3/bin/conda init bash +``` + +2. Restart shell or `source ~/.bashrc` + +3. Create and activate a fresh conda environment for lerobot +```bash +conda create -y -n lerobot python=3.10 && conda activate lerobot +``` + +4. Clone LeRobot: +```bash +git clone https://github.com/huggingface/lerobot.git ~/lerobot +``` + +5. When using `miniconda`, install `ffmpeg` in your environment: +```bash +conda install ffmpeg -c conda-forge +``` + +6. Install LeRobot with dependencies for the Aloha motors (dynamixel) and cameras (intelrealsense): +```bash +cd ~/lerobot && pip install -e ".[dynamixel, intelrealsense]" +``` + +## Teleoperate + +**/!\ FOR SAFETY, READ THIS /!\** +Teleoperation consists in manually operating the leader arms to move the follower arms. Importantly: +1. Make sure your leader arms are in the same position as the follower arms, so that the follower arms don't move too fast to match the leader arms, +2. Our code assumes that your robot has been assembled following Trossen Robotics instructions. This allows us to skip calibration, as we use the pre-defined calibration files in `.cache/calibration/aloha_default`. If you replace a motor, make sure you follow the exact instructions from Trossen Robotics. + +By running the following code, you can start your first **SAFE** teleoperation: + +> **NOTE:** To visualize the data, enable `--control.display_data=true`. This streams the data using `rerun`. + +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=5 \ + --control.type=teleoperate +``` + +By adding `--robot.max_relative_target=5`, we override the default value for `max_relative_target` defined in [`AlohaRobotConfig`](lerobot/robot_devices/robots/configs.py). It is expected to be `5` to limit the magnitude of the movement for more safety, but the teleoperation won't be smooth. When you feel confident, you can disable this limit by adding `--robot.max_relative_target=null` to the command line: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=null \ + --control.type=teleoperate +``` + +## Record a dataset + +Once you're familiar with teleoperation, you can record your first dataset with Aloha. + +If you want to use the Hugging Face hub features for uploading your dataset and you haven't previously done it, make sure you've logged in using a write-access token, which can be generated from the [Hugging Face settings](https://huggingface.co/settings/tokens): +```bash +huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential +``` + +Store your Hugging Face repository name in a variable to run these commands: +```bash +HF_USER=$(huggingface-cli whoami | head -n 1) +echo $HF_USER +``` + +Record 2 episodes and upload your dataset to the hub: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=null \ + --control.type=record \ + --control.fps=30 \ + --control.single_task="Grasp a lego block and put it in the bin." \ + --control.repo_id=${HF_USER}/aloha_test \ + --control.tags='["tutorial"]' \ + --control.warmup_time_s=5 \ + --control.episode_time_s=30 \ + --control.reset_time_s=30 \ + --control.num_episodes=2 \ + --control.push_to_hub=true +``` + +## Visualize a dataset + +If you uploaded your dataset to the hub with `--control.push_to_hub=true`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: +```bash +echo ${HF_USER}/aloha_test +``` + +If you didn't upload with `--control.push_to_hub=false`, you can also visualize it locally with: +```bash +python -m lerobot.scripts.visualize_dataset_html \ + --repo-id ${HF_USER}/aloha_test +``` + +## Replay an episode + +**/!\ FOR SAFETY, READ THIS /!\** +Replay consists in automatically replaying the sequence of actions (i.e. goal positions for your motors) recorded in a given dataset episode. Make sure the current initial position of your robot is similar to the one in your episode, so that your follower arms don't move too fast to go to the first goal positions. For safety, you might want to add `--robot.max_relative_target=5` to your command line as explained above. + +Now try to replay the first episode on your robot: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --robot.max_relative_target=null \ + --control.type=replay \ + --control.fps=30 \ + --control.repo_id=${HF_USER}/aloha_test \ + --control.episode=0 +``` + +## Train a policy + +To train a policy to control your robot, use the [`python -m lerobot.scripts.train`](../src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: +```bash +python -m lerobot.scripts.train \ + --dataset.repo_id=${HF_USER}/aloha_test \ + --policy.type=act \ + --output_dir=outputs/train/act_aloha_test \ + --job_name=act_aloha_test \ + --policy.device=cuda \ + --wandb.enable=true +``` + +Let's explain it: +1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/aloha_test`. +2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](../src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset. +4. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon. +5. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`. + +For more information on the `train` script see the previous tutorial: [`examples/4_train_policy_with_script.md`](../examples/4_train_policy_with_script.md) + +Training should take several hours. You will find checkpoints in `outputs/train/act_aloha_test/checkpoints`. + +## Evaluate your policy + +You can use the `record` function from [`lerobot/scripts/control_robot.py`](../src/lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: +```bash +python lerobot/scripts/control_robot.py \ + --robot.type=aloha \ + --control.type=record \ + --control.fps=30 \ + --control.single_task="Grasp a lego block and put it in the bin." \ + --control.repo_id=${HF_USER}/eval_act_aloha_test \ + --control.tags='["tutorial"]' \ + --control.warmup_time_s=5 \ + --control.episode_time_s=30 \ + --control.reset_time_s=30 \ + --control.num_episodes=10 \ + --control.push_to_hub=true \ + --control.policy.path=outputs/train/act_aloha_test/checkpoints/last/pretrained_model \ + --control.num_image_writer_processes=1 +``` + +As you can see, it's almost the same command as previously used to record your training dataset. Two things changed: +1. There is an additional `--control.policy.path` argument which indicates the path to your policy checkpoint with (e.g. `outputs/train/eval_act_aloha_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `${HF_USER}/act_aloha_test`). +2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `${HF_USER}/eval_act_aloha_test`). +3. We use `--control.num_image_writer_processes=1` instead of the default value (`0`). On our computer, using a dedicated process to write images from the 4 cameras on disk allows to reach constant 30 fps during inference. Feel free to explore different values for `--control.num_image_writer_processes`. + +## More + +Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth explanation. + +If you have any question or need help, please reach out on Discord in the channel `#aloha-arm`. diff --git a/src/lerobot/scripts/eval.py b/src/lerobot/scripts/eval.py index e69de29bb..d85ac27b3 100644 --- a/src/lerobot/scripts/eval.py +++ b/src/lerobot/scripts/eval.py @@ -0,0 +1,506 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Evaluate a policy on an environment by running rollouts and computing metrics. + +Usage examples: + +You want to evaluate a model from the hub (eg: https://huggingface.co/lerobot/diffusion_pusht) +for 10 episodes. + +``` +python -m lerobot.scripts.eval \ + --policy.path=lerobot/diffusion_pusht \ + --env.type=pusht \ + --eval.batch_size=10 \ + --eval.n_episodes=10 \ + --use_amp=false \ + --device=cuda +``` + +OR, you want to evaluate a model checkpoint from the LeRobot training script for 10 episodes. +``` +python -m lerobot.scripts.eval \ + --policy.path=outputs/train/diffusion_pusht/checkpoints/005000/pretrained_model \ + --env.type=pusht \ + --eval.batch_size=10 \ + --eval.n_episodes=10 \ + --use_amp=false \ + --device=cuda +``` + +Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files. + +You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py +""" + +import json +import logging +import threading +import time +from contextlib import nullcontext +from copy import deepcopy +from dataclasses import asdict +from pathlib import Path +from pprint import pformat +from typing import Callable + +import einops +import gymnasium as gym +import numpy as np +import torch +from termcolor import colored +from torch import Tensor, nn +from tqdm import trange + +from lerobot.configs import parser +from lerobot.configs.eval import EvalPipelineConfig +from lerobot.envs.factory import make_env +from lerobot.envs.utils import add_envs_task, check_env_attributes_and_types, preprocess_observation +from lerobot.policies.factory import make_policy +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import get_device_from_parameters +from lerobot.utils.io_utils import write_video +from lerobot.utils.random_utils import set_seed +from lerobot.utils.utils import ( + get_safe_torch_device, + init_logging, + inside_slurm, +) + + +def rollout( + env: gym.vector.VectorEnv, + policy: PreTrainedPolicy, + seeds: list[int] | None = None, + return_observations: bool = False, + render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, +) -> dict: + """Run a batched policy rollout once through a batch of environments. + + Note that all environments in the batch are run until the last environment is done. This means some + data will probably need to be discarded (for environments that aren't the first one to be done). + + The return dictionary contains: + (optional) "observation": A dictionary of (batch, sequence + 1, *) tensors mapped to observation + keys. NOTE that this has an extra sequence element relative to the other keys in the + dictionary. This is because an extra observation is included for after the environment is + terminated or truncated. + "action": A (batch, sequence, action_dim) tensor of actions applied based on the observations (not + including the last observations). + "reward": A (batch, sequence) tensor of rewards received for applying the actions. + "success": A (batch, sequence) tensor of success conditions (the only time this can be True is upon + environment termination/truncation). + "done": A (batch, sequence) tensor of **cumulative** done conditions. For any given batch element, + the first True is followed by True's all the way till the end. This can be used for masking + extraneous elements from the sequences above. + + Args: + env: The batch of environments. + policy: The policy. Must be a PyTorch nn module. + seeds: The environments are seeded once at the start of the rollout. If provided, this argument + specifies the seeds for each of the environments. + return_observations: Whether to include all observations in the returned rollout data. Observations + are returned optionally because they typically take more memory to cache. Defaults to False. + render_callback: Optional rendering callback to be used after the environments are reset, and after + every step. + Returns: + The dictionary described above. + """ + assert isinstance(policy, nn.Module), "Policy must be a PyTorch nn module." + device = get_device_from_parameters(policy) + + # Reset the policy and environments. + policy.reset() + observation, info = env.reset(seed=seeds) + if render_callback is not None: + render_callback(env) + + all_observations = [] + all_actions = [] + all_rewards = [] + all_successes = [] + all_dones = [] + + step = 0 + # Keep track of which environments are done. + done = np.array([False] * env.num_envs) + max_steps = env.call("_max_episode_steps")[0] + progbar = trange( + max_steps, + desc=f"Running rollout with at most {max_steps} steps", + disable=inside_slurm(), # we dont want progress bar when we use slurm, since it clutters the logs + leave=False, + ) + check_env_attributes_and_types(env) + while not np.all(done): + # Numpy array to tensor and changing dictionary keys to LeRobot policy format. + observation = preprocess_observation(observation) + if return_observations: + all_observations.append(deepcopy(observation)) + + observation = { + key: observation[key].to(device, non_blocking=device.type == "cuda") for key in observation + } + + # Infer "task" from attributes of environments. + # TODO: works with SyncVectorEnv but not AsyncVectorEnv + observation = add_envs_task(env, observation) + + with torch.inference_mode(): + action = policy.select_action(observation) + + # Convert to CPU / numpy. + action = action.to("cpu").numpy() + assert action.ndim == 2, "Action dimensions should be (batch, action_dim)" + + # Apply the next action. + observation, reward, terminated, truncated, info = env.step(action) + if render_callback is not None: + render_callback(env) + + # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't + # available of none of the envs finished. + if "final_info" in info: + successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + else: + successes = [False] * env.num_envs + + # Keep track of which environments are done so far. + done = terminated | truncated | done + + all_actions.append(torch.from_numpy(action)) + all_rewards.append(torch.from_numpy(reward)) + all_dones.append(torch.from_numpy(done)) + all_successes.append(torch.tensor(successes)) + + step += 1 + running_success_rate = ( + einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() + ) + progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}) + progbar.update() + + # Track the final observation. + if return_observations: + observation = preprocess_observation(observation) + all_observations.append(deepcopy(observation)) + + # Stack the sequence along the first dimension so that we have (batch, sequence, *) tensors. + ret = { + "action": torch.stack(all_actions, dim=1), + "reward": torch.stack(all_rewards, dim=1), + "success": torch.stack(all_successes, dim=1), + "done": torch.stack(all_dones, dim=1), + } + if return_observations: + stacked_observations = {} + for key in all_observations[0]: + stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) + ret["observation"] = stacked_observations + + if hasattr(policy, "use_original_modules"): + policy.use_original_modules() + + return ret + + +def eval_policy( + env: gym.vector.VectorEnv, + policy: PreTrainedPolicy, + n_episodes: int, + max_episodes_rendered: int = 0, + videos_dir: Path | None = None, + return_episode_data: bool = False, + start_seed: int | None = None, +) -> dict: + """ + Args: + env: The batch of environments. + policy: The policy. + n_episodes: The number of episodes to evaluate. + max_episodes_rendered: Maximum number of episodes to render into videos. + videos_dir: Where to save rendered videos. + return_episode_data: Whether to return episode data for online training. Incorporates the data into + the "episodes" key of the returned dictionary. + start_seed: The first seed to use for the first individual rollout. For all subsequent rollouts the + seed is incremented by 1. If not provided, the environments are not manually seeded. + Returns: + Dictionary with metrics and data regarding the rollouts. + """ + if max_episodes_rendered > 0 and not videos_dir: + raise ValueError("If max_episodes_rendered > 0, videos_dir must be provided.") + + if not isinstance(policy, PreTrainedPolicy): + raise ValueError( + f"Policy of type 'PreTrainedPolicy' is expected, but type '{type(policy)}' was provided." + ) + + start = time.time() + policy.eval() + + # Determine how many batched rollouts we need to get n_episodes. Note that if n_episodes is not evenly + # divisible by env.num_envs we end up discarding some data in the last batch. + n_batches = n_episodes // env.num_envs + int((n_episodes % env.num_envs) != 0) + + # Keep track of some metrics. + sum_rewards = [] + max_rewards = [] + all_successes = [] + all_seeds = [] + threads = [] # for video saving threads + n_episodes_rendered = 0 # for saving the correct number of videos + + # Callback for visualization. + def render_frame(env: gym.vector.VectorEnv): + # noqa: B023 + if n_episodes_rendered >= max_episodes_rendered: + return + n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) + if isinstance(env, gym.vector.SyncVectorEnv): + ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023 + elif isinstance(env, gym.vector.AsyncVectorEnv): + # Here we must render all frames and discard any we don't need. + ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) + + if max_episodes_rendered > 0: + video_paths: list[str] = [] + + if return_episode_data: + episode_data: dict | None = None + + # we dont want progress bar when we use slurm, since it clutters the logs + progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) + for batch_ix in progbar: + # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout + # step. + if max_episodes_rendered > 0: + ep_frames: list[np.ndarray] = [] + + if start_seed is None: + seeds = None + else: + seeds = range( + start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) + ) + rollout_data = rollout( + env, + policy, + seeds=list(seeds) if seeds else None, + return_observations=return_episode_data, + render_callback=render_frame if max_episodes_rendered > 0 else None, + ) + + # Figure out where in each rollout sequence the first done condition was encountered (results after + # this won't be included). + n_steps = rollout_data["done"].shape[1] + # Note: this relies on a property of argmax: that it returns the first occurrence as a tiebreaker. + done_indices = torch.argmax(rollout_data["done"].to(int), dim=1) + + # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done + # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. + mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() + # Extend metrics. + batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum") + sum_rewards.extend(batch_sum_rewards.tolist()) + batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max") + max_rewards.extend(batch_max_rewards.tolist()) + batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") + all_successes.extend(batch_successes.tolist()) + if seeds: + all_seeds.extend(seeds) + else: + all_seeds.append(None) + + # FIXME: episode_data is either None or it doesn't exist + if return_episode_data: + this_episode_data = _compile_episode_data( + rollout_data, + done_indices, + start_episode_index=batch_ix * env.num_envs, + start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), + fps=env.unwrapped.metadata["render_fps"], + ) + if episode_data is None: + episode_data = this_episode_data + else: + # Some sanity checks to make sure we are correctly compiling the data. + assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0] + assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] + # Concatenate the episode data. + episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data} + + # Maybe render video for visualization. + if max_episodes_rendered > 0 and len(ep_frames) > 0: + batch_stacked_frames = np.stack(ep_frames, axis=1) # (b, t, *) + for stacked_frames, done_index in zip( + batch_stacked_frames, done_indices.flatten().tolist(), strict=False + ): + if n_episodes_rendered >= max_episodes_rendered: + break + + videos_dir.mkdir(parents=True, exist_ok=True) + video_path = videos_dir / f"eval_episode_{n_episodes_rendered}.mp4" + video_paths.append(str(video_path)) + thread = threading.Thread( + target=write_video, + args=( + str(video_path), + stacked_frames[: done_index + 1], # + 1 to capture the last observation + env.unwrapped.metadata["render_fps"], + ), + ) + thread.start() + threads.append(thread) + n_episodes_rendered += 1 + + progbar.set_postfix( + {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"} + ) + + # Wait till all video rendering threads are done. + for thread in threads: + thread.join() + + # Compile eval info. + info = { + "per_episode": [ + { + "episode_ix": i, + "sum_reward": sum_reward, + "max_reward": max_reward, + "success": success, + "seed": seed, + } + for i, (sum_reward, max_reward, success, seed) in enumerate( + zip( + sum_rewards[:n_episodes], + max_rewards[:n_episodes], + all_successes[:n_episodes], + all_seeds[:n_episodes], + strict=True, + ) + ) + ], + "aggregated": { + "avg_sum_reward": float(np.nanmean(sum_rewards[:n_episodes])), + "avg_max_reward": float(np.nanmean(max_rewards[:n_episodes])), + "pc_success": float(np.nanmean(all_successes[:n_episodes]) * 100), + "eval_s": time.time() - start, + "eval_ep_s": (time.time() - start) / n_episodes, + }, + } + + if return_episode_data: + info["episodes"] = episode_data + + if max_episodes_rendered > 0: + info["video_paths"] = video_paths + + return info + + +def _compile_episode_data( + rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float +) -> dict: + """Convenience function for `eval_policy(return_episode_data=True)` + + Compiles all the rollout data into a Hugging Face dataset. + + Similar logic is implemented when datasets are pushed to hub (see: `push_to_hub`). + """ + ep_dicts = [] + total_frames = 0 + for ep_ix in range(rollout_data["action"].shape[0]): + # + 2 to include the first done frame and the last observation frame. + num_frames = done_indices[ep_ix].item() + 2 + total_frames += num_frames + + # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. + ep_dict = { + "action": rollout_data["action"][ep_ix, : num_frames - 1], + "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), + "frame_index": torch.arange(0, num_frames - 1, 1), + "timestamp": torch.arange(0, num_frames - 1, 1) / fps, + "next.done": rollout_data["done"][ep_ix, : num_frames - 1], + "next.success": rollout_data["success"][ep_ix, : num_frames - 1], + "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), + } + + # For the last observation frame, all other keys will just be copy padded. + for k in ep_dict: + ep_dict[k] = torch.cat([ep_dict[k], ep_dict[k][-1:]]) + + for key in rollout_data["observation"]: + ep_dict[key] = rollout_data["observation"][key][ep_ix, :num_frames] + + ep_dicts.append(ep_dict) + + data_dict = {} + for key in ep_dicts[0]: + data_dict[key] = torch.cat([x[key] for x in ep_dicts]) + + data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) + + return data_dict + + +@parser.wrap() +def eval_main(cfg: EvalPipelineConfig): + logging.info(pformat(asdict(cfg))) + + # Check device is available + device = get_safe_torch_device(cfg.policy.device, log=True) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + set_seed(cfg.seed) + + logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}") + + logging.info("Making environment.") + env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) + + logging.info("Making policy.") + + policy = make_policy( + cfg=cfg.policy, + env_cfg=cfg.env, + ) + policy.eval() + + with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): + info = eval_policy( + env, + policy, + cfg.eval.n_episodes, + max_episodes_rendered=10, + videos_dir=Path(cfg.output_dir) / "videos", + start_seed=cfg.seed, + ) + print(info["aggregated"]) + + # Save info + with open(Path(cfg.output_dir) / "eval_info.json", "w") as f: + json.dump(info, f, indent=2) + + env.close() + + logging.info("End of eval") + + +if __name__ == "__main__": + init_logging() + eval_main() diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index e69de29bb..4bcc241da 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -0,0 +1,215 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +######################################################################################## +# Utilities +######################################################################################## + + +import logging +import traceback +from contextlib import nullcontext +from copy import copy +from functools import cache + +import numpy as np +import torch +from deepdiff import DeepDiff +from termcolor import colored + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import DEFAULT_FEATURES +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.robots import Robot + + +def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): + log_items = [] + if episode_index is not None: + log_items.append(f"ep:{episode_index}") + if frame_index is not None: + log_items.append(f"frame:{frame_index}") + + def log_dt(shortname, dt_val_s): + nonlocal log_items, fps + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" + if fps is not None: + actual_fps = 1 / dt_val_s + if actual_fps < fps - 1: + info_str = colored(info_str, "yellow") + log_items.append(info_str) + + # total step time displayed in milliseconds and its frequency + log_dt("dt", dt_s) + + # TODO(aliberts): move robot-specific logs logic in robot.print_logs() + if not robot.robot_type.startswith("stretch"): + for name in robot.leader_arms: + key = f"read_leader_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRlead", robot.logs[key]) + + for name in robot.follower_arms: + key = f"write_follower_{name}_goal_pos_dt_s" + if key in robot.logs: + log_dt("dtWfoll", robot.logs[key]) + + key = f"read_follower_{name}_pos_dt_s" + if key in robot.logs: + log_dt("dtRfoll", robot.logs[key]) + + for name in robot.cameras: + key = f"read_camera_{name}_dt_s" + if key in robot.logs: + log_dt(f"dtR{name}", robot.logs[key]) + + info_str = " ".join(log_items) + logging.info(info_str) + + +@cache +def is_headless(): + """Detects if python is running without a monitor.""" + try: + import pynput # noqa + + return False + except Exception: + print( + "Error trying to import pynput. Switching to headless mode. " + "As a result, the video stream from the cameras won't be shown, " + "and you won't be able to change the control flow with keyboards. " + "For more info, see traceback below.\n" + ) + traceback.print_exc() + print() + return True + + +def predict_action( + observation: dict[str, np.ndarray], + policy: PreTrainedPolicy, + device: torch.device, + use_amp: bool, + task: str | None = None, + robot_type: str | None = None, +): + observation = copy(observation) + with ( + torch.inference_mode(), + torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + ): + # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension + for name in observation: + observation[name] = torch.from_numpy(observation[name]) + if "image" in name: + observation[name] = observation[name].type(torch.float32) / 255 + observation[name] = observation[name].permute(2, 0, 1).contiguous() + observation[name] = observation[name].unsqueeze(0) + observation[name] = observation[name].to(device) + + observation["task"] = task if task else "" + observation["robot_type"] = robot_type if robot_type else "" + + # Compute the next action with the policy + # based on the current observation + action = policy.select_action(observation) + + # Remove batch dimension + action = action.squeeze(0) + + # Move to cpu, if not already the case + action = action.to("cpu") + + return action + + +def init_keyboard_listener(): + # Allow to exit early while recording an episode or resetting the environment, + # by tapping the right arrow key '->'. This might require a sudo permission + # to allow your terminal to monitor keyboard events. + events = {} + events["exit_early"] = False + events["rerecord_episode"] = False + events["stop_recording"] = False + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + listener = None + return listener, events + + # Only import pynput if not in a headless environment + from pynput import keyboard + + def on_press(key): + try: + if key == keyboard.Key.right: + print("Right arrow key pressed. Exiting loop...") + events["exit_early"] = True + elif key == keyboard.Key.left: + print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + events["rerecord_episode"] = True + events["exit_early"] = True + elif key == keyboard.Key.esc: + print("Escape key pressed. Stopping data recording...") + events["stop_recording"] = True + events["exit_early"] = True + except Exception as e: + print(f"Error handling key press: {e}") + + listener = keyboard.Listener(on_press=on_press) + listener.start() + + return listener, events + + +def sanity_check_dataset_name(repo_id, policy_cfg): + _, dataset_name = repo_id.split("/") + # either repo_id doesnt start with "eval_" and there is no policy + # or repo_id starts with "eval_" and there is a policy + + # Check if dataset_name starts with "eval_" but policy is missing + if dataset_name.startswith("eval_") and policy_cfg is None: + raise ValueError( + f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})." + ) + + # Check if dataset_name does not start with "eval_" but policy is provided + if not dataset_name.startswith("eval_") and policy_cfg is not None: + raise ValueError( + f"Your dataset name does not begin with 'eval_' ({dataset_name}), but a policy is provided ({policy_cfg.type})." + ) + + +def sanity_check_dataset_robot_compatibility( + dataset: LeRobotDataset, robot: Robot, fps: int, features: dict +) -> None: + fields = [ + ("robot_type", dataset.meta.robot_type, robot.robot_type), + ("fps", dataset.fps, fps), + ("features", dataset.features, {**features, **DEFAULT_FEATURES}), + ] + + mismatches = [] + for field, dataset_value, present_value in fields: + diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) + if diff: + mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") + + if mismatches: + raise ValueError( + "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) + ) diff --git a/tests/configs/test_plugin_loading.py b/tests/configs/test_plugin_loading.py index 1a8cceed7..957574eb4 100644 --- a/tests/configs/test_plugin_loading.py +++ b/tests/configs/test_plugin_loading.py @@ -5,15 +5,15 @@ from typing import Generator import pytest -from lerobot.common.envs.configs import EnvConfig from lerobot.configs.parser import PluginLoadError, load_plugin, parse_plugin_args, wrap +from lerobot.envs.configs import EnvConfig def create_plugin_code(*, base_class: str = "EnvConfig", plugin_name: str = "test_env") -> str: """Creates a dummy plugin module that implements its own EnvConfig subclass.""" return f""" from dataclasses import dataclass -from lerobot.common.envs.configs import {base_class} +from lerobot.envs.configs import {base_class} @{base_class}.register_subclass("{plugin_name}") @dataclass diff --git a/tests/datasets/test_image_transforms.py b/tests/datasets/test_image_transforms.py index 146a4dcd4..3ab93cb2c 100644 --- a/tests/datasets/test_image_transforms.py +++ b/tests/datasets/test_image_transforms.py @@ -21,7 +21,7 @@ from safetensors.torch import load_file from torchvision.transforms import v2 from torchvision.transforms.v2 import functional as F # noqa: N812 -from lerobot.common.datasets.transforms import ( +from lerobot.datasets.transforms import ( ImageTransformConfig, ImageTransforms, ImageTransformsConfig, @@ -29,11 +29,11 @@ from lerobot.common.datasets.transforms import ( SharpnessJitter, make_transform_from_config, ) -from lerobot.common.utils.random_utils import seeded_context from lerobot.scripts.visualize_image_transforms import ( save_all_transforms, save_each_transform, ) +from lerobot.utils.random_utils import seeded_context from tests.artifacts.image_transforms.save_image_transforms_to_safetensors import ARTIFACT_DIR from tests.utils import require_x86_64_kernel diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index b318abb4a..140e9dfb9 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -21,8 +21,8 @@ import torch from gymnasium.utils.env_checker import check_env import lerobot -from lerobot.common.envs.factory import make_env, make_env_config -from lerobot.common.envs.utils import preprocess_observation +from lerobot.envs.factory import make_env, make_env_config +from lerobot.envs.utils import preprocess_observation from tests.utils import require_env OBS_TYPES = ["state", "pixels", "pixels_agent_pos"] diff --git a/tests/fixtures/constants.py b/tests/fixtures/constants.py index 81b9be39b..0af499364 100644 --- a/tests/fixtures/constants.py +++ b/tests/fixtures/constants.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from lerobot.common.constants import HF_LEROBOT_HOME +from lerobot.constants import HF_LEROBOT_HOME LEROBOT_TEST_DIR = HF_LEROBOT_HOME / "_testing" DUMMY_REPO_ID = "dummy/repo" diff --git a/tests/optim/test_optimizers.py b/tests/optim/test_optimizers.py index 630353fca..4152c7f8d 100644 --- a/tests/optim/test_optimizers.py +++ b/tests/optim/test_optimizers.py @@ -14,11 +14,11 @@ import pytest import torch -from lerobot.common.constants import ( +from lerobot.constants import ( OPTIMIZER_PARAM_GROUPS, OPTIMIZER_STATE, ) -from lerobot.common.optim.optimizers import ( +from lerobot.optim.optimizers import ( AdamConfig, AdamWConfig, MultiAdamConfig, diff --git a/tests/policies/test_policies.py b/tests/policies/test_policies.py index 3a8c6a224..bd6c99801 100644 --- a/tests/policies/test_policies.py +++ b/tests/policies/test_policies.py @@ -23,23 +23,23 @@ import torch from safetensors.torch import load_file from lerobot import available_policies -from lerobot.configs.default import DatasetConfig -from lerobot.configs.train import TrainPipelineConfig -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.datasets.factory import make_dataset -from lerobot.datasets.utils import cycle, dataset_to_policy_features -from lerobot.envs.factory import make_env, make_env_config -from lerobot.envs.utils import preprocess_observation -from lerobot.optim.factory import make_optimizer_and_scheduler -from lerobot.policies.act.modeling_act import ACTTemporalEnsembler -from lerobot.policies.factory import ( +from lerobot.common.datasets.factory import make_dataset +from lerobot.common.datasets.utils import cycle, dataset_to_policy_features +from lerobot.common.envs.factory import make_env, make_env_config +from lerobot.common.envs.utils import preprocess_observation +from lerobot.common.optim.factory import make_optimizer_and_scheduler +from lerobot.common.policies.act.modeling_act import ACTTemporalEnsembler +from lerobot.common.policies.factory import ( get_policy_class, make_policy, make_policy_config, ) -from lerobot.policies.normalize import Normalize, Unnormalize -from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.utils.random_utils import seeded_context +from lerobot.common.policies.normalize import Normalize, Unnormalize +from lerobot.common.policies.pretrained import PreTrainedPolicy +from lerobot.common.utils.random_utils import seeded_context +from lerobot.configs.default import DatasetConfig +from lerobot.configs.train import TrainPipelineConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel