diff --git a/docs/source/libero.mdx b/docs/source/libero.mdx
index 488c02ce0..661ecc566 100644
--- a/docs/source/libero.mdx
+++ b/docs/source/libero.mdx
@@ -124,3 +124,7 @@ python src/lerobot/scripts/train.py \
 LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
 
 - `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
+
+## Reproducing π₀ and π₀.₅ results
+
+You can also reproduce the π₀ and π₀.₅ results on the Libero benchmark by loading the finetuned Libero models into the evaluation script (via `--policy.path` together with `--env.type=libero`).
diff --git a/docs/source/pi0.mdx b/docs/source/pi0.mdx
index 35247ce3a..2d8887dc7 100644
--- a/docs/source/pi0.mdx
+++ b/docs/source/pi0.mdx
@@ -35,7 +35,7 @@ As described by Physical Intelligence, while AI has achieved remarkable success
 2. Apply the custom patches:
 
 ```bash
-cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
+cp -r ./src/lerobot/policies/pi0/transformers_replace/* \
   $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
 ```
 
@@ -72,7 +72,7 @@ pip install transformers==4.53.2
 To use π₀ in LeRobot, specify the policy type as:
 
 ```python
-policy.type=pi0_openpi
+policy.type=pi0
 ```
 
 ## Training
@@ -82,7 +82,7 @@ For training π₀, you can use the standard LeRobot training script with the ap
 ```bash
 python src/lerobot/scripts/train.py \
   --dataset.repo_id=your_dataset \
-  --policy.type=pi0_openpi \
+  --policy.type=pi0 \
   --output_dir=./outputs/pi0_training \
   --job_name=pi0_training \
   --policy.pretrained_path=pepijn223/pi0_base_fp32 \
diff --git a/docs/source/pi05.mdx b/docs/source/pi05.mdx
index 3ea4a1fb3..01320dc88 100644
--- a/docs/source/pi05.mdx
+++ b/docs/source/pi05.mdx
@@ -43,7 +43,7 @@ This diverse training mixture creates a "curriculum" that enables generalization
 2. Apply the custom patches:
 
 ```bash
-cp -r ./src/lerobot/policies/pi05_openpi/transformers_replace/* \
+cp -r ./src/lerobot/policies/pi05/transformers_replace/* \
   $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
 ```
 
@@ -72,7 +72,7 @@ pip install transformers==4.53.2
 To use π₀.₅ in your LeRobot configuration, specify the policy type as:
 
 ```python
-policy.type=pi05_openpi
+policy.type=pi05
 ```
 
 ## Training
@@ -84,7 +84,7 @@ Here's a complete training command for finetuning the base π₀.₅ model on yo
 ```bash
 python src/lerobot/scripts/train.py \
   --dataset.repo_id=your_dataset \
-  --policy.type=pi05_openpi \
+  --policy.type=pi05 \
   --output_dir=./outputs/pi0_training \
   --job_name=pi0_training \
   --policy.repo_id=pepijn223/pi05_base_fp32 \
diff --git a/pyproject.toml b/pyproject.toml
index 45c4146f0..c5b3d0185 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -147,8 +147,7 @@ all = [
     "lerobot[reachy2]",
     "lerobot[kinematics]",
     "lerobot[intelrealsense]",
-    "lerobot[pi0]",
-    "lerobot[pi05]",
+    "lerobot[pi]",
     "lerobot[smolvla]",
     "lerobot[hilserl]",
     "lerobot[async]",
diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py
index db3aa4039..c0b12c121 100644
--- a/src/lerobot/policies/__init__.py
+++ b/src/lerobot/policies/__init__.py
@@ -14,10 +14,8 @@
 from .act.configuration_act import ACTConfig as ACTConfig
 from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
-from .pi0.configuration_pi0 import PI0Config as PI0Config
-from .pi0.processor_pi0 import Pi0NewLineProcessor
-from .pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig
-from .pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig
+from .pi0.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig
+from .pi05.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig
 from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
 from .smolvla.processor_smolvla import SmolVLANewLineProcessor
 from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -26,7 +24,8 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
 __all__ = [
     "ACTConfig",
     "DiffusionConfig",
-    "PI0Config",
+    "PI0OpenPIConfig",
+    "PI05OpenPIConfig",
     "SmolVLAConfig",
     "TDMPCConfig",
     "VQBeTConfig",
diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py
index f178d0801..da66ac400 100644
--- a/src/lerobot/policies/factory.py
+++ b/src/lerobot/policies/factory.py
@@ -31,10 +31,9 @@ from lerobot.envs.configs import EnvConfig
 from lerobot.envs.utils import env_to_policy_features
 from lerobot.policies.act.configuration_act import ACTConfig
 from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
-from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
+from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig
 from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
-from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
+from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
 from lerobot.policies.pretrained import PreTrainedPolicy
 from lerobot.policies.sac.configuration_sac import SACConfig
 from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -83,20 +82,16 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
         from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
 
         return VQBeTPolicy
-    elif name == "pi0":
-        from lerobot.policies.pi0.modeling_pi0 import PI0Policy
-
-        return PI0Policy
     elif name == "pi0fast":
         from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
 
         return PI0FASTPolicy
-    elif name == "pi0_openpi":
-        from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
+    elif name == "pi0":
+        from lerobot.policies.pi0.modeling_pi0openpi import PI0OpenPIPolicy
 
         return PI0OpenPIPolicy
-    elif name == "pi05_openpi":
-        from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
+    elif name == "pi05":
+        from lerobot.policies.pi05.modeling_pi05openpi import PI05OpenPIPolicy
 
         return PI05OpenPIPolicy
     elif name == "sac":
@@ -142,13 +137,11 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
         return ACTConfig(**kwargs)
     elif policy_type == "vqbet":
         return VQBeTConfig(**kwargs)
-    elif policy_type == "pi0":
-        return PI0Config(**kwargs)
     elif policy_type == "pi0fast":
         return PI0FASTConfig(**kwargs)
-    elif policy_type == "pi0_openpi":
+    elif policy_type == "pi0":
         return PI0OpenPIConfig(**kwargs)
-    elif policy_type == "pi05_openpi":
+    elif policy_type == "pi05":
         return PI05OpenPIConfig(**kwargs)
     elif policy_type == "sac":
         return SACConfig(**kwargs)
@@ -267,14 +260,6 @@ def make_pre_post_processors(
             dataset_stats=kwargs.get("dataset_stats"),
         )
 
-    elif isinstance(policy_cfg, PI0Config):
-        from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
-
-        processors = make_pi0_pre_post_processors(
-            config=policy_cfg,
-            dataset_stats=kwargs.get("dataset_stats"),
-        )
-
     elif isinstance(policy_cfg, PI0FASTConfig):
         from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
 
diff --git a/src/lerobot/policies/pi0_openpi/README.md b/src/lerobot/policies/pi0/README.md
similarity index 100%
rename from src/lerobot/policies/pi0_openpi/README.md
rename to src/lerobot/policies/pi0/README.md
diff --git a/src/lerobot/policies/pi0_openpi/__init__.py b/src/lerobot/policies/pi0/__init__.py
similarity index 100%
rename from src/lerobot/policies/pi0_openpi/__init__.py
rename to src/lerobot/policies/pi0/__init__.py
diff --git a/src/lerobot/policies/pi0/configuration_pi0.py b/src/lerobot/policies/pi0/configuration_pi0.py
deleted file mode 100644
index c9728e418..000000000
--- a/src/lerobot/policies/pi0/configuration_pi0.py
+++ /dev/null
@@ -1,149 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass, field
-
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
-from lerobot.optim.optimizers import AdamWConfig
-from lerobot.optim.schedulers import (
-    CosineDecayWithWarmupSchedulerConfig,
-)
-
-
-@PreTrainedConfig.register_subclass("pi0")
-@dataclass
-class PI0Config(PreTrainedConfig):
-    # Input / output structure.
-    n_obs_steps: int = 1
-    chunk_size: int = 50
-    n_action_steps: int = 50
-
-    normalization_mapping: dict[str, NormalizationMode] = field(
-        default_factory=lambda: {
-            "VISUAL": NormalizationMode.IDENTITY,
-            "STATE": NormalizationMode.MEAN_STD,
-            "ACTION": NormalizationMode.MEAN_STD,
-        }
-    )
-
-    # Shorter state and action vectors will be padded
-    max_state_dim: int = 32
-    max_action_dim: int = 32
-
-    # Image preprocessing
-    resize_imgs_with_padding: tuple[int, int] = (224, 224)
-
-    # Add empty images. Used by pi0_aloha_sim which adds the empty
-    # left and right wrist cameras in addition to the top camera.
-    empty_cameras: int = 0
-
-    # Converts the joint and gripper values from the standard Aloha space to
-    # the space used by the pi internal runtime which was used to train the base model.
-    adapt_to_pi_aloha: bool = False
-
-    # Converts joint dimensions to deltas with respect to the current state before passing to the model.
-    # Gripper dimensions will remain in absolute values.
-    use_delta_joint_actions_aloha: bool = False
-
-    # Tokenizer
-    tokenizer_max_length: int = 48
-
-    # Projector
-    proj_width: int = 1024
-
-    # Decoding
-    num_steps: int = 10
-
-    # Attention utils
-    use_cache: bool = True
-    attention_implementation: str = "eager"  # or fa2, flex
-
-    # Finetuning settings
-    freeze_vision_encoder: bool = True
-    train_expert_only: bool = False
-    train_state_proj: bool = True
-
-    # Training presets
-    optimizer_lr: float = 2.5e-5
-    optimizer_betas: tuple[float, float] = (0.9, 0.95)
-    optimizer_eps: float = 1e-8
-    optimizer_weight_decay: float = 1e-10
-
-    scheduler_warmup_steps: int = 1_000
-    scheduler_decay_steps: int = 30_000
-    scheduler_decay_lr: float = 2.5e-6
-
-    # TODO: Add EMA
-
-    def __post_init__(self):
-        super().__post_init__()
-
-        # TODO(Steven): Validate device and amp? in all policy configs?
-        """Input validation (not exhaustive)."""
-        if self.n_action_steps > self.chunk_size:
-            raise ValueError(
-                f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
-                f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
-            )
-        if self.n_obs_steps != 1:
-            raise ValueError(
-                f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
-            )
-
-        if self.use_delta_joint_actions_aloha:
-            raise NotImplementedError(
-                "`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
-            )
-
-    def validate_features(self) -> None:
-        # TODO: implement value error
-        # if not self.image_features and not self.env_state_feature:
-        #     raise ValueError("You must provide at least one image or the environment state among the inputs.")
-
-        for i in range(self.empty_cameras):
-            key = f"observation.images.empty_camera_{i}"
-            empty_camera = PolicyFeature(
-                type=FeatureType.VISUAL,
-                shape=(3, 480, 640),
-            )
-            self.input_features[key] = empty_camera
-
-    def get_optimizer_preset(self) -> AdamWConfig:
-        return AdamWConfig(
-            lr=self.optimizer_lr,
-            betas=self.optimizer_betas,
-            eps=self.optimizer_eps,
-            weight_decay=self.optimizer_weight_decay,
-        )
-
-    def get_scheduler_preset(self):
-        return CosineDecayWithWarmupSchedulerConfig(
-            peak_lr=self.optimizer_lr,
-            decay_lr=self.scheduler_decay_lr,
-            num_warmup_steps=self.scheduler_warmup_steps,
-            num_decay_steps=self.scheduler_decay_steps,
-        )
-
-    @property
-    def observation_delta_indices(self) -> None:
-        return None
-
-    @property
-    def action_delta_indices(self) -> list:
-        return list(range(self.chunk_size))
-
-    @property
-    def reward_delta_indices(self) -> None:
-        return None
diff --git a/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py b/src/lerobot/policies/pi0/configuration_pi0openpi.py
similarity index 99%
rename from src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py
rename to src/lerobot/policies/pi0/configuration_pi0openpi.py
index 0bef9c5a1..7402b2b89 100644
--- a/src/lerobot/policies/pi0_openpi/configuration_pi0openpi.py
+++ b/src/lerobot/policies/pi0/configuration_pi0openpi.py
@@ -22,7 +22,7 @@ from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
 
 
-@PreTrainedConfig.register_subclass("pi0_openpi")
+@PreTrainedConfig.register_subclass("pi0")
 @dataclass
 class PI0OpenPIConfig(PreTrainedConfig):
     # Model architecture
diff --git a/src/lerobot/policies/pi0/conversion_scripts/benchmark.py b/src/lerobot/policies/pi0/conversion_scripts/benchmark.py
deleted file mode 100644
index c1a488244..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/benchmark.py
+++ /dev/null
@@ -1,82 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
-from lerobot.policies.factory import make_policy
-
-torch.backends.cudnn.benchmark = True
-
-
-def main():
-    device = "cuda"
-    dataset_repo_id = "danaaubakirova/koch_test"
-    # model_name = "pi0_base"
-    # ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
-    ckpt_torch_dir = "lerobot/pi0"
-
-    dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
-
-    dataloader = torch.utils.data.DataLoader(
-        dataset,
-        num_workers=0,
-        batch_size=1,
-    )
-
-    batch = next(iter(dataloader))
-
-    # To device
-    for k in batch:
-        if isinstance(batch[k], torch.Tensor):
-            batch[k] = batch[k].to(device=device, dtype=torch.float32)
-
-    cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
-    cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, ds_meta=dataset.meta)
-
-    # policy = torch.compile(policy, mode="reduce-overhead")
-
-    warmup_iters = 10
-    benchmark_iters = 30
-
-    # Warmup
-    for _ in range(warmup_iters):
-        torch.cuda.synchronize()
-        policy.select_action(batch)
-        policy.reset()
-        torch.cuda.synchronize()
-
-    # Benchmark
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-
-    start_event.record()
-    for _ in range(benchmark_iters):
-        policy.select_action(batch)
-        policy.reset()
-    end_event.record()
-
-    # Synchronize and measure time
-    torch.cuda.synchronize()
-    elapsed_time_ms = start_event.elapsed_time(end_event)
-
-    avg_time_per_iter = elapsed_time_ms / benchmark_iters
-    print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
-
-
-if __name__ == "__main__":
-    with torch.inference_mode():
-        main()
diff --git a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py b/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py
deleted file mode 100644
index c0c2e4816..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/compare_with_jax.py
+++ /dev/null
@@ -1,131 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import pickle
-from pathlib import Path
-
-import torch
-
-from lerobot.configs.policies import PreTrainedConfig
-from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
-from lerobot.policies.factory import make_policy
-
-
-def display(tensor: torch.Tensor):
-    if tensor.dtype == torch.bool:
-        tensor = tensor.float()
-    print(f"Shape: {tensor.shape}")
-    print(f"Mean: {tensor.mean().item()}")
-    print(f"Std: {tensor.std().item()}")
-    print(f"Min: {tensor.min().item()}")
-    print(f"Max: {tensor.max().item()}")
-
-
-def main():
-    num_motors = 14
-    device = "cuda"
-    # model_name = "pi0_aloha_towel"
-    model_name = "pi0_aloha_sim"
-
-    if model_name == "pi0_aloha_towel":
-        dataset_repo_id = "lerobot/aloha_static_towel"
-    else:
-        dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
-
-    ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
-    ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
-    save_dir = Path(f"../openpi/data/{model_name}/save")
-
-    with open(save_dir / "example.pkl", "rb") as f:
-        example = pickle.load(f)
-    with open(save_dir / "outputs.pkl", "rb") as f:
-        outputs = pickle.load(f)
-    with open(save_dir / "noise.pkl", "rb") as f:
-        noise = pickle.load(f)
-
-    with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
-        norm_stats = json.load(f)
-
-    # Override stats
-    dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
-    dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
-        norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
-    )
-    dataset_meta.stats["observation.state"]["std"] = torch.tensor(
-        norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
-    )
-
-    # Create LeRobot batch from Jax
-    batch = {}
-    for cam_key, uint_chw_array in example["images"].items():
-        batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
-    batch["observation.state"] = torch.from_numpy(example["state"])
-    batch["action"] = torch.from_numpy(outputs["actions"])
-    batch["task"] = example["prompt"]
-
-    if model_name == "pi0_aloha_towel":
-        del batch["observation.images.cam_low"]
-    elif model_name == "pi0_aloha_sim":
-        batch["observation.images.top"] = batch["observation.images.cam_high"]
-        del batch["observation.images.cam_high"]
-
-    # Batchify
-    for key in batch:
-        if isinstance(batch[key], torch.Tensor):
-            batch[key] = batch[key].unsqueeze(0)
-        elif isinstance(batch[key], str):
-            batch[key] = [batch[key]]
-        else:
-            raise ValueError(f"{key}, {batch[key]}")
-
-    # To device
-    for k in batch:
-        if isinstance(batch[k], torch.Tensor):
-            batch[k] = batch[k].to(device=device, dtype=torch.float32)
-
-    noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
-
-    from lerobot import policies  # noqa
-
-    cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
-    cfg.pretrained_path = ckpt_torch_dir
-    policy = make_policy(cfg, dataset_meta)
-
-    # loss_dict = policy.forward(batch, noise=noise, time=time_beta)
-    # loss_dict["loss"].backward()
-    # print("losses")
-    # display(loss_dict["losses_after_forward"])
-    # print("pi_losses")
-    # display(pi_losses)
-
-    actions = []
-    for _ in range(50):
-        action = policy.select_action(batch, noise=noise)
-        actions.append(action)
-
-    actions = torch.stack(actions, dim=1)
-    pi_actions = batch["action"]
-    print("actions")
-    display(actions)
-    print()
-    print("pi_actions")
-    display(pi_actions)
-    print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
-    print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
-    print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
-
-
-if __name__ == "__main__":
-    main()
diff --git a/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py b/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py
deleted file mode 100644
index 8835da31e..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/conversion_utils.py
+++ /dev/null
@@ -1,84 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from transformers import GemmaConfig, PaliGemmaConfig
-
-
-def get_paligemma_config(precision: str):
-    config = {
-        "image_token_index": None,
-        "pad_token_id": 0,
-        "bos_token_id": 2,
-        "eos_token_id": 1,
-    }
-
-    # image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
-
-    image_size = 224  # image_sizes[variant]
-    patch_size = 14
-    num_image_tokens = (image_size**2) // (patch_size**2)
-
-    config["image_token_index"] = 257152
-    text_config = {
-        "vocab_size": 257152,
-        "num_hidden_layers": 18,
-        "num_key_value_heads": 1,
-        "head_dim": 256,
-        "torch_dtype": precision,
-        "hidden_size": 2048,
-        "hidden_activation": "gelu_pytorch_tanh",
-        "num_attention_heads": 8,
-        "intermediate_size": 16384,
-        "is_encoder_decoder": False,
-    }
-    vision_config = {
-        "torch_dtype": precision,
-        "image_size": image_size,
-        "patch_size": patch_size,
-        "num_image_tokens": num_image_tokens,
-        "hidden_size": 1152,
-        "intermediate_size": 4304,
-        "num_hidden_layers": 27,
-        "num_attention_heads": 16,
-        "projector_hidden_act": "gelu_fast",
-        "vision_use_head": False,
-    }
-    final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
-    return final_config
-
-
-def get_gemma_config(precision: str):
-    config = {
-        "image_token_index": None,
-        "pad_token_id": 0,
-        "bos_token_id": 2,
-        "eos_token_id": 1,
-    }
-
-    config["image_token_index"] = 257152
-    text_config = {
-        "vocab_size": 257152,
-        "num_hidden_layers": 18,
-        "num_key_value_heads": 1,
-        "head_dim": 256,
-        "torch_dtype": precision,
-        "hidden_size": 1024,
-        "hidden_activation": "gelu_pytorch_tanh",
-        "num_attention_heads": 8,
-        "intermediate_size": 4096,
-        "is_encoder_decoder": False,
-    }
-    final_config = GemmaConfig()
-    final_config.update(text_config)
-    return final_config
diff --git a/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py b/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
deleted file mode 100644
index 742c9ab3f..000000000
--- a/src/lerobot/policies/pi0/conversion_scripts/convert_pi0_to_hf_lerobot.py
+++ /dev/null
@@ -1,437 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Convert pi0 parameters from Jax to Pytorch
-
-Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
-and install the required libraries.
-
-```bash
-cd ~/code/openpi
-source .venv/bin/activate
-```
-
-Example downloading parameters:
-```bash
-python
->>> import openpi.shared.download as download
->>> path='s3://openpi-assets/checkpoints/pi0_base/params'
->>> download.maybe_download(path)
-```
-
-Converting pi0_base:
-```python
-python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
-    --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
-    --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
-```
-
-```python
-python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
-    --checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
-    --output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
-```
-"""
-
-import argparse
-import pathlib
-
-import jax
-import numpy as np
-import orbax.checkpoint as ocp
-import torch
-from jax.sharding import SingleDeviceSharding
-
-from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.policies.pi0.conversion_scripts.conversion_utils import (
-    get_gemma_config,
-    get_paligemma_config,
-)
-from lerobot.policies.pi0.modeling_pi0 import PI0Policy
-
-PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
-
-
-def slice_paligemma_state_dict(state_dict, config):
-    suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
-
-    # fmt: off
-    # patch embeddings
-    state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
-        3, 2, 0, 1
-    )
-    state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
-    # positional embeddings
-    state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
-        -1, config.vision_config.hidden_size
-    )
-
-    # extract vision layers to be sliced at index 0. There are 27 layers in the base model.
-    encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
-    encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
-    encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
-    encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
-
-    encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
-    encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
-    encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
-    encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
-
-    encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
-    encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
-    encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
-    encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
-    encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
-    encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
-    encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
-    encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
-
-    for i in range(config.vision_config.num_hidden_layers):
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
-
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
-        state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose() - state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1) - - state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose() - state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}") - - # multimodal projector - - state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose() - state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}") - - # text decoder (gemma) - embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}") - state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector - - # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b. - - llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}") - llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}") - llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}") - - llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}") - llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}") - # TODO verify correctness of layer norm loading - - llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}") - llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}") - - for i in range(config.text_config.num_hidden_layers): - # llm_attention_q_einsum[i].shape = (8, 2048, 256) - q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size) - - state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped - - # llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256) - k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() - state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped - # llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256) - v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() - state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped - - # output projection. 
-
-        # llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
-        o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
-
-        state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
-        # mlp layers
-        gate_proj_weight = llm_mlp_gating_einsum[i, 0]
-        state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
-        up_proj_weight = llm_mlp_gating_einsum[i, 1]
-        state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
-        state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
-        state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
-        state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
-
-    state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
-    state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector  # weights are tied.
-
-    # fmt: on
-    expert_dict = {}
-    final_state_dict = {}
-    for key, value in state_dict.items():
-        if key not in [
-            f"llm/final_norm_1/scale{suffix}",
-            f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
-            f"llm/layers/attn/kv_einsum_1/w{suffix}",
-            f"llm/layers/attn/q_einsum_1/w{suffix}",
-            f"llm/layers/mlp_1/gating_einsum{suffix}",
-            f"llm/layers/mlp_1/linear{suffix}",
-            f"llm/layers/pre_attention_norm_1/scale{suffix}",
-            f"llm/layers/pre_ffw_norm_1/scale{suffix}",
-        ]:
-            final_state_dict[key] = torch.from_numpy(value)
-        else:
-            expert_dict[key] = value
-
-    return final_state_dict, expert_dict
-
-
-def slice_gemma_state_dict(state_dict, config, num_expert=1):
-    # fmt: off
-    # text decoder (gemma)
-    # no embedding vector, the expert just has the decoder layers
-
-    embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
-    state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
-
-    # pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
- - suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else "" - - llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}") - llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}") - llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}") - - llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}") - llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}") - # TODO verify correctness of layer norm loading - - llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}") - llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}") - - for i in range(config.num_hidden_layers): - q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size) - - state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped - - k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose() - state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped - v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose() - state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped - - # output projection. - - # llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024) - o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0) - - state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped - # mlp layers - gate_proj_weight = llm_mlp_gating_einsum[i, 0] - state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose() - up_proj_weight = llm_mlp_gating_einsum[i, 1] - state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose() - state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose() - state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i] - state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i] - - state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}") - state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here) - - # fmt: on - final_state_dict = {} - for key, value in state_dict.items(): - if not isinstance(value, torch.Tensor): - final_state_dict[key] = torch.from_numpy(value) - else: - final_state_dict[key] = value - return final_state_dict - - -def flatten_for_memory(tree, parent_key=""): - out = {} - for k, v in tree.items(): - new_key = f"{parent_key}/{k}" if parent_key else k - if isinstance(v, dict): - out.update(flatten_for_memory(v, new_key)) - else: - out[new_key] = np.array(v) # Ensure conversion to np.array for consistency - return out - - -def flatten_for_npz(tree, parent_key=""): - out = {} - for k, v in tree.items(): - new_key = f"{parent_key}/{k}" if parent_key else k - if isinstance(v, dict): - out.update(flatten_for_npz(v, new_key)) - else: - # bf16/f32 here? 
-            out[new_key] = np.array(v)
-    return out
-
-
-def slice_initial_orbax_checkpoint(checkpoint_dir: str):
-    params_path = pathlib.Path(checkpoint_dir).resolve()
-    checkpointer = ocp.PyTreeCheckpointer()
-
-    metadata = checkpointer.metadata(params_path)
-    print("Metadata keys:", list(metadata.keys()))
-
-    params_name = "params"
-
-    item = {params_name: metadata[params_name]}
-    device = jax.local_devices()[0]  # Use the first local device
-    sharding = SingleDeviceSharding(device)
-    restored = checkpointer.restore(
-        params_path,
-        ocp.args.PyTreeRestore(
-            item=item,
-            restore_args=jax.tree_util.tree_map(
-                lambda _: ocp.ArrayRestoreArgs(
-                    restore_type=jax.Array,  # or np.ndarray, but bf16 is annoying about it
-                    sharding=sharding,
-                ),
-                item,
-            ),
-            transforms={},
-        ),
-    )
-    params = restored[params_name]
-
-    # get params for PaliGemma
-    pali_params = params["PaliGemma"]
-    del params["PaliGemma"]
-    pali_params_flat = flatten_for_npz(pali_params)
-    return {"paligemma_params": pali_params_flat, "projection_params": params}
-
-
-def update_keys_with_prefix(d: dict, prefix: str) -> dict:
-    """Update dictionary keys by adding a prefix."""
-    return {f"{prefix}{key}": value for key, value in d.items()}
-
-
-def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
-    # Break down orbax ckpts - they are in OCDBT
-    initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
-    # process projection params
-    keys = [
-        "state_proj",
-        "action_in_proj",
-        "action_out_proj",
-        "action_time_mlp_in",
-        "action_time_mlp_out",
-    ]
-
-    projection_params = {}
-    for key in keys:
-        kernel_params = initial_params["projection_params"][key]["kernel"]
-        bias_params = initial_params["projection_params"][key]["bias"]
-        if isinstance(kernel_params, dict):
-            weight = kernel_params["value"]
-            bias = bias_params["value"]
-        else:
-            weight = kernel_params
-            bias = bias_params
-        projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
-        projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
-
-    # Process PaliGemma weights
-    paligemma_config = get_paligemma_config(precision)
-    paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
-        initial_params["paligemma_params"], paligemma_config
-    )
-
-    # Process Gemma weights (at this stage they are unused)
-    gemma_config = get_gemma_config(precision)
-    gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
-
-    # Instantiate model from configs
-
-    if "pi0_aloha_sim" in checkpoint_dir:
-        pi0_config = PI0Config(
-            empty_cameras=2,
-            adapt_to_pi_aloha=True,
-            use_delta_joint_actions_aloha=False,
-        )
-    elif "pi0_aloha_towel" in checkpoint_dir:
-        pi0_config = PI0Config(
-            adapt_to_pi_aloha=True,
-            use_delta_joint_actions_aloha=True,
-        )
-    elif "pi0_base" in checkpoint_dir:
-        pi0_config = PI0Config(
-            empty_cameras=0,
-            adapt_to_pi_aloha=False,
-            use_delta_joint_actions_aloha=False,
-        )
-    else:
-        raise ValueError()
-
-    # gemma_config=gemma_config, paligemma_config=paligemma_config)
-    pi0_model = PI0Policy(pi0_config)
-
-    paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
-    gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
-    projection_params = update_keys_with_prefix(projection_params, "model.")
-
-    # load state dict
-    torch_dtype = PRECISIONS[precision]
-    pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
-    pi0_model = pi0_model.to(torch_dtype)
-    # pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
-
-    pi0_model.save_pretrained(output_path, safe_serialization=True)
-    # pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
-
-    # assert that model loads properly
-    del pi0_model
-    PI0Policy.from_pretrained(output_path)
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--checkpoint_dir",
-        default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
-        type=str,
-        help="Path to the ocdbt checkpoint",
-    )
-
-    parser.add_argument(
-        "--precision",
-        choices=["float32", "bfloat16", "float16"],
-        default="float32",
-        type=str,
-        help="Precision identifier for model conversion - should match the base checkpoint precision.",
-    )
-    # tokenizer is identical to paligemma, it appears
-
-    parser.add_argument(
-        "--tokenizer_hub_id",
-        default="google/paligemma-3b-pt-224",
-        type=str,
-        help="Hub path to the tokenizer to save",
-    )
-
-    parser.add_argument(
-        "--output_path",
-        required=True,
-        type=str,
-        help="Path to save converted weights to",
-    )
-
-    args = parser.parse_args()
-    convert_pi0_checkpoint(
-        checkpoint_dir=args.checkpoint_dir,
-        precision=args.precision,
-        tokenizer_id=args.tokenizer_hub_id,
-        output_path=args.output_path,
-    )
diff --git a/src/lerobot/policies/pi0/flex_attention.py b/src/lerobot/policies/pi0/flex_attention.py
deleted file mode 100644
index 35628cddb..000000000
--- a/src/lerobot/policies/pi0/flex_attention.py
+++ /dev/null
@@ -1,141 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn.functional as F  # noqa: N812
-from packaging.version import Version
-
-if Version(torch.__version__) > Version("2.5.0"):
-    # Ffex attention is only available from torch 2.5 onwards
-    from torch.nn.attention.flex_attention import (
-        _mask_mod_signature,
-        _round_up_to_multiple,
-        create_block_mask,
-        create_mask,
-        flex_attention,
-    )
-
-
-# @torch.compile(dynamic=False)
-def flex_attention_forward(
-    attention_mask: torch.Tensor,
-    batch_size: int,
-    head_dim: int,
-    query_states: torch.Tensor,
-    key_states: torch.Tensor,
-    value_states: torch.Tensor,
-    scaling=None,
-):
-    """
-    This is defined out of classes to make compile happy.
- """ - - original_dtype = query_states.dtype - num_att_heads = 8 - num_key_value_heads = 1 - num_key_value_groups = num_att_heads // num_key_value_heads - - key_states = key_states[:, :, :, None, :] - key_states = key_states.expand( - batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim - ) - key_states = key_states.reshape( - batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim - ) - - value_states = value_states[:, :, :, None, :] - value_states = value_states.expand( - batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim - ) - value_states = value_states.reshape( - batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim - ) - - query_states = query_states.transpose(1, 2) - key_states = key_states.transpose(1, 2) - value_states = value_states.transpose(1, 2) - - query_states = query_states.to(torch.float32) - key_states = key_states.to(torch.float32) - value_states = value_states.to(torch.float32) - - causal_mask = attention_mask - if causal_mask is not None: - causal_mask = causal_mask[:, None, :, : key_states.shape[2]] - - if causal_mask.shape[1] == 1 and query_states.shape[1] > 1: - causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1) - - def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature: - def mask_mod(b, h, q_idx, kv_idx): - # Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs. - return precomputed_mask[b][h][q_idx][kv_idx] - - return mask_mod - - b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask - - block_size = 128 - q_len_rounded = _round_up_to_multiple(q_len, block_size) - kv_len_rounded = _round_up_to_multiple(kv_len, block_size) - - # *CRITICAL* we do need to expand here, else we get a CUDA index error - - pad_q = q_len_rounded - q_len - pad_k = kv_len_rounded - kv_len - - padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0) - mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask) - - mask_4d = create_mask( - mod_fn=mask_mod_fn_orig, - B=b_mask, - H=h_mask, - Q_LEN=q_len_rounded, - KV_LEN=kv_len_rounded, - device=causal_mask.device, - _compile=False, - ) - - mask_mod_fn_padded = precomputed_mask_factory(mask_4d) - block_mask = create_block_mask( - mask_mod=mask_mod_fn_padded, - B=b_mask, - H=h_mask, - Q_LEN=q_len_rounded, - KV_LEN=kv_len_rounded, - BLOCK_SIZE=block_size, - device=causal_mask.device, - _compile=False, - ) - - # mask is applied inside the kernel, ideally more efficiently than score_mod. - attn_output, attention_weights = flex_attention( - query_states, - key_states, - value_states, - block_mask=block_mask, - enable_gqa=True, # because we shaped query/key states for GQA - scale=head_dim**-0.5 if scaling is None else scaling, - return_lse=True, - ) - - attn_output = attn_output.to(dtype=original_dtype) - attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim] - attn_output = attn_output.reshape( - batch_size, - -1, - attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim] - ) - return attn_output diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py deleted file mode 100644 index 66bd81e61..000000000 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ /dev/null @@ -1,705 +0,0 @@ -#!/usr/bin/env python - -# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. 
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-π0: A Vision-Language-Action Flow Model for General Robot Control
-
-[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
-[Jax code](https://github.com/Physical-Intelligence/openpi)
-
-Designed by Physical Intelligence. Ported from Jax by Hugging Face.
-Disclaimer: It is not expected to perform as well as the original implementation.
-
-Install pi0 extra dependencies:
-```bash
-pip install -e ".[pi0]"
-```
-
-Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
-```bash
-lerobot-train \
---policy.path=lerobot/pi0 \
---dataset.repo_id=danaaubakirova/koch_test
-```
-
-Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
-pretrained with VLM default parameters before pi0 finetuning:
-```bash
-lerobot-train \
---policy.type=pi0 \
---dataset.repo_id=danaaubakirova/koch_test
-```
-
-Example of using the pi0 pretrained model outside LeRobot training framework:
-```python
-policy = Pi0Policy.from_pretrained("lerobot/pi0")
-```
-
-"""
-
-import math
-from collections import deque
-
-import torch
-import torch.nn.functional as F  # noqa: N812
-from torch import Tensor, nn
-
-from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
-from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.policies.pi0.paligemma_with_expert import (
-    PaliGemmaWithExpertConfig,
-    PaliGemmaWithExpertModel,
-)
-from lerobot.policies.pretrained import PreTrainedPolicy
-from lerobot.utils.utils import get_safe_dtype
-
-
-def create_sinusoidal_pos_embedding(
-    time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
-) -> Tensor:
-    """Computes sine-cosine positional embedding vectors for scalar positions."""
-    if dimension % 2 != 0:
-        raise ValueError(f"dimension ({dimension}) must be divisible by 2")
-
-    if time.ndim != 1:
-        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
-
-    dtype = get_safe_dtype(torch.float64, device.type)
-    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
-    period = min_period * (max_period / min_period) ** fraction
-
-    # Compute the outer product
-    scaling_factor = 1.0 / period * 2 * math.pi
-    sin_input = scaling_factor[None, :] * time[:, None]
-    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
-    return pos_emb
-
-
-def make_att_2d_masks(pad_masks, att_masks):
-    """Copied from big_vision.
-
-    Tokens can attend to valid inputs tokens which have a cumulative mask_ar
-    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
-    setup several types of attention, for example:
-
-      [[1 1 1 1 1 1]]: pure causal attention.
-
-      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
-          themselves and the last 3 tokens have a causal attention. The first
-          entry could also be a 1 without changing behaviour.
-
-      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
-          block can attend all previous blocks and all tokens on the same block.
-
-    Args:
-      input_mask: bool[B, N] true if its part of the input, false if padding.
-      mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
-        it and 0 where it shares the same attention mask as the previous token.
-    """
-    if att_masks.ndim != 2:
-        raise ValueError(att_masks.ndim)
-    if pad_masks.ndim != 2:
-        raise ValueError(pad_masks.ndim)
-
-    cumsum = torch.cumsum(att_masks, dim=1)
-    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
-    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
-    att_2d_masks = att_2d_masks & pad_2d_masks
-    return att_2d_masks
-
-
-def resize_with_pad(img, width, height, pad_value=-1):
-    # assume no-op when width height fits already
-    if img.ndim != 4:
-        raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
-
-    cur_height, cur_width = img.shape[2:]
-
-    ratio = max(cur_width / width, cur_height / height)
-    resized_height = int(cur_height / ratio)
-    resized_width = int(cur_width / ratio)
-    resized_img = F.interpolate(
-        img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
-    )
-
-    pad_height = max(0, int(height - resized_height))
-    pad_width = max(0, int(width - resized_width))
-
-    # pad on left and top of image
-    padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
-    return padded_img
-
-
-def pad_vector(vector, new_dim):
-    """Can be (batch_size x sequence_length x features_dimension)
-    or (batch_size x features_dimension)
-    """
-    if vector.shape[-1] == new_dim:
-        return vector
-    shape = list(vector.shape)
-    current_dim = shape[-1]
-    shape[-1] = new_dim
-    new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
-    new_vector[..., :current_dim] = vector
-    return new_vector
-
-
-def normalize(x, min_val, max_val):
-    return (x - min_val) / (max_val - min_val)
-
-
-def unnormalize(x, min_val, max_val):
-    return x * (max_val - min_val) + min_val
-
-
-def safe_arcsin(value):
-    # This ensures that the input stays within
-    # [−1,1] to avoid invalid values for arcsin
-    return torch.arcsin(torch.clamp(value, -1.0, 1.0))
-
-
-def aloha_gripper_to_angular(value):
-    # Aloha transforms the gripper positions into a linear space. The following code
-    # reverses this transformation to be consistent with pi0 which is pretrained in
-    # angular space.
-    #
-    # These values are coming from the Aloha code:
-    # PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
-    value = unnormalize(value, min_val=0.01844, max_val=0.05800)
-
-    # This is the inverse of the angular to linear transformation inside the Interbotix code.
-    def linear_to_radian(linear_position, arm_length, horn_radius):
-        value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
-        return safe_arcsin(value)
-
-    # The constants are taken from the Interbotix code.
-    value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
-
-    # Normalize to [0, 1].
-    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
-    return normalize(value, min_val=0.4, max_val=1.5)
-
-
-def aloha_gripper_from_angular(value):
-    # Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
-    # Note that the units are still angular but the range is different.
-
-    # The values 0.4 and 1.5 were measured on an actual Trossen robot.
-    value = unnormalize(value, min_val=0.4, max_val=1.5)
-
-    # These values are coming from the Aloha code:
-    # PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
-    return normalize(value, min_val=-0.6213, max_val=1.4910)
-
-
-def aloha_gripper_from_angular_inv(value):
-    # Directly inverts the gripper_from_angular function.
-    value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
-    return normalize(value, min_val=0.4, max_val=1.5)
-
-
-class PI0Policy(PreTrainedPolicy):
-    """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
-
-    config_class = PI0Config
-    name = "pi0"
-
-    def __init__(
-        self,
-        config: PI0Config,
-    ):
-        """
-        Args:
-            config: Policy configuration class instance or None, in which case the default instantiation of
-                the configuration class is used.
-        """
-
-        super().__init__(config)
-        config.validate_features()
-        self.config = config
-
-        self.model = PI0FlowMatching(config)
-
-        self.reset()
-
-    def reset(self):
-        """This should be called whenever the environment is reset."""
-        self._action_queue = deque([], maxlen=self.config.n_action_steps)
-
-    def get_optim_params(self) -> dict:
-        return self.parameters()
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        """Override the from_pretrained method to display important disclaimer."""
-        print(
-            "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
-            "   It is not expected to perform as well as the original implementation. \n"
-            "   Original implementation: https://github.com/Physical-Intelligence/openpi"
-        )
-        return super().from_pretrained(*args, **kwargs)
-
-    @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
-        """Predict a chunk of actions given environment observations."""
-        raise NotImplementedError("Currently not implemented for PI0")
-
-    @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
-        self.eval()
-
-        if self.config.adapt_to_pi_aloha:
-            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            images, img_masks = self.prepare_images(batch)
-            state = self.prepare_state(batch)
-            lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
-            lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
-
-            actions = self.model.sample_actions(
-                images, img_masks, lang_tokens, lang_masks, state, noise=noise
-            )
-
-            # Unpad actions
-            original_action_dim = self.config.action_feature.shape[0]
-            actions = actions[:, :, :original_action_dim]
-
-            if self.config.adapt_to_pi_aloha:
-                actions = self._pi_aloha_encode_actions(actions)
-
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-
-
-class PI0Policy(PreTrainedPolicy):
-    """Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
-
-    config_class = PI0Config
-    name = "pi0"
-
-    def __init__(
-        self,
-        config: PI0Config,
-    ):
-        """
-        Args:
-            config: Policy configuration class instance or None, in which case the default instantiation of
-                the configuration class is used.
-        """
-
-        super().__init__(config)
-        config.validate_features()
-        self.config = config
-
-        self.model = PI0FlowMatching(config)
-
-        self.reset()
-
-    def reset(self):
-        """This should be called whenever the environment is reset."""
-        self._action_queue = deque([], maxlen=self.config.n_action_steps)
-
-    def get_optim_params(self) -> dict:
-        return self.parameters()
-
-    @classmethod
-    def from_pretrained(cls, *args, **kwargs):
-        """Override the from_pretrained method to display important disclaimer."""
-        print(
-            "⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
-            "   It is not expected to perform as well as the original implementation. \n"
-            "   Original implementation: https://github.com/Physical-Intelligence/openpi"
-        )
-        return super().from_pretrained(*args, **kwargs)
-
-    @torch.no_grad()
-    def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
-        """Predict a chunk of actions given environment observations."""
-        raise NotImplementedError("Currently not implemented for PI0")
-
-    @torch.no_grad()
-    def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
-        """Select a single action given environment observations.
-
-        This method wraps `select_actions` in order to return one action at a time for execution in the
-        environment. It works by managing the actions in a queue and only calling `select_actions` when the
-        queue is empty.
-        """
-        self.eval()
-
-        if self.config.adapt_to_pi_aloha:
-            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-
-        # Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
-        # querying the policy.
-        if len(self._action_queue) == 0:
-            images, img_masks = self.prepare_images(batch)
-            state = self.prepare_state(batch)
-            lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
-            lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
-
-            actions = self.model.sample_actions(
-                images, img_masks, lang_tokens, lang_masks, state, noise=noise
-            )
-
-            # Unpad actions
-            original_action_dim = self.config.action_feature.shape[0]
-            actions = actions[:, :, :original_action_dim]
-
-            if self.config.adapt_to_pi_aloha:
-                actions = self._pi_aloha_encode_actions(actions)
-
-            # `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
-            # effectively has shape (n_action_steps, batch_size, *), hence the transpose.
-            self._action_queue.extend(actions.transpose(0, 1))
-        return self._action_queue.popleft()
-
-    def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
-        """Do a full training forward pass to compute the loss"""
-        if self.config.adapt_to_pi_aloha:
-            batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
-            batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
-
-        images, img_masks = self.prepare_images(batch)
-        state = self.prepare_state(batch)
-        lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
-        lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
-        actions = self.prepare_action(batch)
-        actions_is_pad = batch.get("action_is_pad")
-
-        loss_dict = {}
-        losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
-        loss_dict["losses_after_forward"] = losses.clone()
-
-        if actions_is_pad is not None:
-            in_episode_bound = ~actions_is_pad
-            losses = losses * in_episode_bound.unsqueeze(-1)
-            loss_dict["losses_after_in_ep_bound"] = losses.clone()
-
-        # Remove padding
-        losses = losses[:, :, : self.config.max_action_dim]
-        loss_dict["losses_after_rm_padding"] = losses.clone()
-
-        # For backward pass
-        loss = losses.mean()
-        # For logging
-        loss_dict["l2_loss"] = loss.item()
-
-        return loss, loss_dict
-
-    def prepare_images(self, batch):
-        """Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
-        convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
-        """
-        images = []
-        img_masks = []
-
-        present_img_keys = [key for key in self.config.image_features if key in batch]
-        missing_img_keys = [key for key in self.config.image_features if key not in batch]
-
-        if len(present_img_keys) == 0:
-            raise ValueError(
-                f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
-            )
-
-        # Preprocess image features present in the batch
-        for key in present_img_keys:
-            img = batch[key]
-
-            if self.config.resize_imgs_with_padding is not None:
-                img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
-
-            # Normalize from range [0,1] to [-1,1] as expected by siglip
-            img = img * 2.0 - 1.0
-
-            bsize = img.shape[0]
-            device = img.device
-            mask = torch.ones(bsize, dtype=torch.bool, device=device)
-            images.append(img)
-            img_masks.append(mask)
-
-        # Create image features not present in the batch
-        # as fully 0 padded images.
-        for num_empty_cameras in range(len(missing_img_keys)):
-            if num_empty_cameras >= self.config.empty_cameras:
-                break
-            img = torch.ones_like(img) * -1
-            mask = torch.zeros_like(mask)
-            images.append(img)
-            img_masks.append(mask)
-
-        return images, img_masks
-
-    def _pi_aloha_decode_state(self, state):
-        # Flip the joints.
-        for motor_idx in [1, 2, 8, 9]:
-            state[:, motor_idx] *= -1
-        # Reverse the gripper transformation that is being applied by the Aloha runtime.
-        for motor_idx in [6, 13]:
-            state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
-        return state
-
-    def _pi_aloha_encode_actions(self, actions):
-        # Flip the joints.
-        for motor_idx in [1, 2, 8, 9]:
-            actions[:, :, motor_idx] *= -1
-        # Reverse the gripper transformation that is being applied by the Aloha runtime.
-        for motor_idx in [6, 13]:
-            actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
-        return actions
-
-    def _pi_aloha_encode_actions_inv(self, actions):
-        # Flip the joints again.
-        for motor_idx in [1, 2, 8, 9]:
-            actions[:, :, motor_idx] *= -1
-        # Reverse the gripper transformation that is being applied by the Aloha runtime.
-        for motor_idx in [6, 13]:
-            actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
-        return actions
-
-    def prepare_state(self, batch):
-        """Pad state"""
-        state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
-        return state
-
-    def prepare_action(self, batch):
-        """Pad action"""
-        actions = pad_vector(batch[ACTION], self.config.max_action_dim)
-        return actions
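A usage sketch of the action queue handled by `select_action` above (`make_observation` and `env` are hypothetical): the underlying model is sampled only once every `n_action_steps` environment steps, and one action is popped per call.

```python
policy.reset()  # empty the queue whenever the environment resets
for _ in range(100):
    batch = make_observation()            # hypothetical batched observation dict
    action = policy.select_action(batch)  # refills the queue only when it is empty
    env.step(action)
```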
- """ - # TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty - embs = [] - pad_masks = [] - att_masks = [] - - # TODO: remove for loop - for ( - img, - img_mask, - ) in zip(images, img_masks, strict=False): - img_emb = self.paligemma_with_expert.embed_image(img) - img_emb = img_emb.to(dtype=torch.bfloat16) - - # Normalize image embeddings - img_emb_dim = img_emb.shape[-1] - img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device) - - bsize, num_img_embs = img_emb.shape[:2] - img_mask = img_mask[:, None].expand(bsize, num_img_embs) - - embs.append(img_emb) - pad_masks.append(img_mask) - - # Create attention masks so that image tokens attend to each other - att_masks += [0] * num_img_embs - - lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens) - - # Normalize language embeddings - lang_emb_dim = lang_emb.shape[-1] - lang_emb = lang_emb * math.sqrt(lang_emb_dim) - - embs.append(lang_emb) - pad_masks.append(lang_masks) - - # full attention between image and language inputs - num_lang_embs = lang_emb.shape[1] - att_masks += [0] * num_lang_embs - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device) - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - - return embs, pad_masks, att_masks - - def embed_suffix(self, state, noisy_actions, timestep): - """Embed state, noisy_actions, timestep to prepare for Expert Gemma processing.""" - embs = [] - pad_masks = [] - att_masks = [] - - # Embed state - state_emb = self.state_proj(state) - state_emb = state_emb.to(dtype=torch.bfloat16) - embs.append(state_emb[:, None, :]) - bsize = state_emb.shape[0] - dtype = state_emb.dtype - device = state_emb.device - - state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device) - pad_masks.append(state_mask) - - # Set attention masks so that image and language inputs do not attend to state or actions - att_masks += [1] - - # Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1] - time_emb = create_sinusoidal_pos_embedding( - timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device - ) - time_emb = time_emb.type(dtype=dtype) - - # Fuse timestep + action information using an MLP - action_emb = self.action_in_proj(noisy_actions) - - time_emb = time_emb[:, None, :].expand_as(action_emb) - action_time_emb = torch.cat([action_emb, time_emb], dim=2) - - action_time_emb = self.action_time_mlp_in(action_time_emb) - action_time_emb = F.silu(action_time_emb) # swish == silu - action_time_emb = self.action_time_mlp_out(action_time_emb) - - # Add to input tokens - embs.append(action_time_emb) - - bsize, action_time_dim = action_time_emb.shape[:2] - action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device) - pad_masks.append(action_time_mask) - - # Set attention masks so that image, language and state inputs do not attend to action tokens - att_masks += [1] + ([0] * (self.config.n_action_steps - 1)) - - embs = torch.cat(embs, dim=1) - pad_masks = torch.cat(pad_masks, dim=1) - att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device) - att_masks = att_masks[None, :].expand(bsize, len(att_masks)) - - return embs, pad_masks, att_masks - - def forward( - self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None - ) -> Tensor: - """Do a full training forward pass and compute the loss (batch_size x 
num_steps x num_motors)""" - if noise is None: - noise = self.sample_noise(actions.shape, actions.device) - - if time is None: - time = self.sample_time(actions.shape[0], actions.device) - - time_expanded = time[:, None, None] - x_t = time_expanded * noise + (1 - time_expanded) * actions - u_t = noise - actions - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks - ) - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time) - - pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1) - att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1) - - att_2d_masks = make_att_2d_masks(pad_masks, att_masks) - position_ids = torch.cumsum(pad_masks, dim=1) - 1 - - (_, suffix_out), _ = self.paligemma_with_expert.forward( - attention_mask=att_2d_masks, - position_ids=position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, suffix_embs], - use_cache=False, - fill_kv_cache=False, - ) - suffix_out = suffix_out[:, -self.config.n_action_steps :] - # Original openpi code, upcast attention output - suffix_out = suffix_out.to(dtype=torch.float32) - v_t = self.action_out_proj(suffix_out) - - losses = F.mse_loss(u_t, v_t, reduction="none") - return losses - - def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor: - """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" - bsize = state.shape[0] - device = state.device - - if noise is None: - actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim) - noise = self.sample_noise(actions_shape, device) - - prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( - images, img_masks, lang_tokens, lang_masks - ) - prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks) - prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 - - # Compute image and language key value cache - _, past_key_values = self.paligemma_with_expert.forward( - attention_mask=prefix_att_2d_masks, - position_ids=prefix_position_ids, - past_key_values=None, - inputs_embeds=[prefix_embs, None], - use_cache=self.config.use_cache, - fill_kv_cache=True, - ) - - dt = -1.0 / self.config.num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) - - x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) - while time >= -dt / 2: - expanded_time = time.expand(bsize) - v_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - - # Euler step - x_t += dt * v_t - time += dt - return x_t - - def denoise_step( - self, - state, - prefix_pad_masks, - past_key_values, - x_t, - timestep, - ): - """Apply one denoising step of the noise `x_t` at a given timestep.""" - suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep) - - suffix_len = suffix_pad_masks.shape[1] - batch_size = prefix_pad_masks.shape[0] - prefix_len = prefix_pad_masks.shape[1] - prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len) - - suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks) - - full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2) - - prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None] - position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1 - - outputs_embeds, _ = self.paligemma_with_expert.forward( - attention_mask=full_att_2d_masks, - 
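The removed model trains on the flow-matching field u_t = noise − actions along the path x_t = t·noise + (1 − t)·actions, and `sample_actions` integrates dx/dt = v_t from t = 1 (pure noise) to t = 0 with Euler steps of −1/num_steps. A toy sketch with an ideal predictor:

```python
import torch

actions = torch.tensor([0.3, -0.7])  # made-up target action chunk
noise = torch.randn(2)

x_t, t, num_steps = noise.clone(), 1.0, 10
dt = -1.0 / num_steps
while t >= -dt / 2:
    v_t = noise - actions  # an ideal model predicts exactly u_t
    x_t += dt * v_t
    t += dt

print(torch.allclose(x_t, actions, atol=1e-5))  # True: Euler integration recovers the actions
```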
diff --git a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py b/src/lerobot/policies/pi0/modeling_pi0openpi.py
similarity index 99%
rename from src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py
rename to src/lerobot/policies/pi0/modeling_pi0openpi.py
index 4db15a3d3..7be238889 100644
--- a/src/lerobot/policies/pi0_openpi/modeling_pi0openpi.py
+++ b/src/lerobot/policies/pi0/modeling_pi0openpi.py
@@ -33,7 +33,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.constants import ACTION, OBS_STATE
 from lerobot.policies.normalize import Normalize, Unnormalize
-from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
+from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig
 from lerobot.policies.pretrained import PreTrainedPolicy, T
 
 
@@ -525,7 +525,7 @@ class PI0Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         msg = """transformers_replace is not installed correctly.
 Please install it with `pip install transformers==4.53.2`
-and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
+and `cp -r ./src/lerobot/policies/pi0/transformers_replace/* \
 $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`"""
 
         try:
@@ -846,7 +846,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
     """PI0 OpenPI Policy for LeRobot."""
 
     config_class = PI0OpenPIConfig
-    name = "pi0_openpi"
+    name = "pi0"
 
     def __init__(  # see lerobot pi0 `__init__`
         self,
diff --git a/src/lerobot/policies/pi0/paligemma_with_expert.py b/src/lerobot/policies/pi0/paligemma_with_expert.py
deleted file mode 100644
index edc34b7c5..000000000
--- a/src/lerobot/policies/pi0/paligemma_with_expert.py
+++ /dev/null
@@ -1,420 +0,0 @@
-# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import torch
-import torch.version
-from pytest import Cache
-from torch import nn
-from transformers import (
-    AutoConfig,
-    GemmaForCausalLM,
-    PaliGemmaForConditionalGeneration,
-    PretrainedConfig,
-    PreTrainedModel,
-)
-from transformers.models.auto import CONFIG_MAPPING
-
-from lerobot.policies.pi0.flex_attention import flex_attention_forward
-
-
-def apply_rope(x, positions, max_wavelength=10_000):
-    """
-    Applies RoPE positions [B, L] to x [B, L, H, D].
- """ - d_half = x.shape[-1] // 2 - device = x.device - dtype = x.dtype - x = x.to(torch.float32) - - freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device) - timescale = max_wavelength**freq_exponents - radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32) - - radians = radians[..., None, :] - - sin = torch.sin(radians) # .to(dtype=dtype) - cos = torch.cos(radians) # .to(dtype=dtype) - - x1, x2 = x.split(d_half, dim=-1) - res = torch.empty_like(x) - res[..., :d_half] = x1 * cos - x2 * sin - res[..., d_half:] = x2 * cos + x1 * sin - - return res.to(dtype) - - -class PaliGemmaWithExpertConfig(PretrainedConfig): - model_type = "PaliGemmaWithExpertModel" - sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig} - - def __init__( - self, - paligemma_config: dict | None = None, - gemma_expert_config: dict | None = None, - freeze_vision_encoder: bool = True, - train_expert_only: bool = True, - attention_implementation: str = "eager", - **kwargs, - ): - self.freeze_vision_encoder = freeze_vision_encoder - self.train_expert_only = train_expert_only - self.attention_implementation = attention_implementation - - if paligemma_config is None: - # Default config from Pi0 - self.paligemma_config = CONFIG_MAPPING["paligemma"]( - transformers_version="4.48.1", - _vocab_size=257152, - bos_token_id=2, - eos_token_id=1, - hidden_size=2048, - image_token_index=257152, - model_type="paligemma", - pad_token_id=0, - projection_dim=2048, - text_config={ - "hidden_activation": "gelu_pytorch_tanh", - "hidden_size": 2048, - "intermediate_size": 16384, - "model_type": "gemma", - "num_attention_heads": 8, - "num_hidden_layers": 18, - "num_image_tokens": 256, - "num_key_value_heads": 1, - "torch_dtype": "float32", - "vocab_size": 257152, - }, - vision_config={ - "hidden_size": 1152, - "intermediate_size": 4304, - "model_type": "siglip_vision_model", - "num_attention_heads": 16, - "num_hidden_layers": 27, - "num_image_tokens": 256, - "patch_size": 14, - "projection_dim": 2048, - "projector_hidden_act": "gelu_fast", - "torch_dtype": "float32", - "vision_use_head": False, - }, - ) - elif isinstance(self.paligemma_config, dict): - # Override Pi0 default config for PaliGemma - if "model_type" not in gemma_expert_config: - paligemma_config["model_type"] = "paligemma" - - cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] - self.paligemma_config = cfg_cls(**paligemma_config) - - if gemma_expert_config is None: - # Default config from Pi0 - self.gemma_expert_config = CONFIG_MAPPING["gemma"]( - attention_bias=False, - attention_dropout=0.0, - bos_token_id=2, - eos_token_id=1, - head_dim=256, - hidden_act="gelu_pytorch_tanh", - hidden_activation="gelu_pytorch_tanh", - hidden_size=1024, - initializer_range=0.02, - intermediate_size=4096, - max_position_embeddings=8192, - model_type="gemma", - num_attention_heads=8, - num_hidden_layers=18, - num_key_value_heads=1, - pad_token_id=0, - rms_norm_eps=1e-06, - rope_theta=10000.0, - torch_dtype="float32", - transformers_version="4.48.1", - use_cache=True, - vocab_size=257152, - ) - elif isinstance(self.gemma_expert_config, dict): - # Override Pi0 default config for Gemma Expert - if "model_type" not in gemma_expert_config: - gemma_expert_config["model_type"] = "gemma" - - cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]] - self.gemma_expert_config = cfg_cls(**gemma_expert_config) - - super().__init__(**kwargs) - - def __post_init__(self): - super().__post_init__() - if 
-    def __post_init__(self):
-        super().__post_init__()
-        if self.train_expert_only and not self.freeze_vision_encoder:
-            raise ValueError(
-                "You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
-            )
-
-        if self.attention_implementation not in ["eager", "fa2", "flex"]:
-            raise ValueError(
-                f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
-            )
-
-
-class PaliGemmaWithExpertModel(PreTrainedModel):
-    config_class = PaliGemmaWithExpertConfig
-
-    def __init__(self, config: PaliGemmaWithExpertConfig):
-        super().__init__(config=config)
-        self.config = config
-        self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
-        self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
-        # Remove unused embed_tokens
-        self.gemma_expert.model.embed_tokens = None
-
-        self.to_bfloat16_like_physical_intelligence()
-        self.set_requires_grad()
-
-    def set_requires_grad(self):
-        if self.config.freeze_vision_encoder:
-            self.paligemma.vision_tower.eval()
-            for params in self.paligemma.vision_tower.parameters():
-                params.requires_grad = False
-
-        if self.config.train_expert_only:
-            self.paligemma.eval()
-            for params in self.paligemma.parameters():
-                params.requires_grad = False
-
-    def train(self, mode: bool = True):
-        super().train(mode)
-
-        if self.config.freeze_vision_encoder:
-            self.paligemma.vision_tower.eval()
-
-        if self.config.train_expert_only:
-            self.paligemma.eval()
-
-    def to_bfloat16_like_physical_intelligence(self):
-        self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
-
-        params_to_change_dtype = [
-            "language_model.model.layers",
-            "gemma_expert.model.layers",
-            "vision_tower",
-            "multi_modal",
-        ]
-        for name, param in self.named_parameters():
-            if any(selector in name for selector in params_to_change_dtype):
-                param.data = param.data.to(dtype=torch.bfloat16)
-
-    def embed_image(self, image: torch.Tensor):
-        # Handle different transformers versions
-        if hasattr(self.paligemma, "get_image_features"):
-            return self.paligemma.get_image_features(image)
-        else:
-            return self.paligemma.model.get_image_features(image)
-
-    def embed_language_tokens(self, tokens: torch.Tensor):
-        return self.paligemma.language_model.embed_tokens(tokens)
-
-    # TODO: break down this huge forward into modules or functions
-    def forward(
-        self,
-        attention_mask: torch.Tensor | None = None,
-        position_ids: torch.LongTensor | None = None,
-        past_key_values: list[torch.FloatTensor] | Cache | None = None,
-        inputs_embeds: list[torch.FloatTensor] = None,
-        use_cache: bool | None = None,
-        fill_kv_cache: bool | None = None,
-    ):
-        models = [self.paligemma.language_model, self.gemma_expert.model]
-
-        for hidden_states in inputs_embeds:
-            # TODO this is very inefficient
-            # dtype is always the same, batch size too (if > 1 len)
-            # device could be trickier in multi gpu edge cases but that's it
-            if hidden_states is None:
-                continue
-            batch_size = hidden_states.shape[0]
-
-        # RMSNorm
-        num_layers = self.paligemma.config.text_config.num_hidden_layers
-        head_dim = self.paligemma.config.text_config.head_dim
-        for layer_idx in range(num_layers):
-            query_states = []
-            key_states = []
-            value_states = []
-            for i, hidden_states in enumerate(inputs_embeds):
-                if hidden_states is None:
-                    continue
-                layer = models[i].layers[layer_idx]
-                # normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
-                # hidden_states = hidden_states * normalizer
-                hidden_states = layer.input_layernorm(hidden_states)
-
-                input_shape = hidden_states.shape[:-1]
-                hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
-
-                hidden_states = hidden_states.to(dtype=torch.bfloat16)
-                query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
-                key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
-                value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
-
-                query_states.append(query_state)
-                key_states.append(key_state)
-                value_states.append(value_state)
-
-            # B,L,H,D with L sequence length, H number of heads, D head dim
-            # concatenate on the number of embeddings/tokens
-            query_states = torch.cat(query_states, dim=1)
-            key_states = torch.cat(key_states, dim=1)
-            value_states = torch.cat(value_states, dim=1)
-
-            query_states = apply_rope(query_states, position_ids)
-            key_states = apply_rope(key_states, position_ids)
-
-            if use_cache and past_key_values is None:
-                past_key_values = {}
-
-            if use_cache:
-                if fill_kv_cache:
-                    past_key_values[layer_idx] = {
-                        "key_states": key_states,
-                        "value_states": value_states,
-                    }
-                else:
-                    # TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
-                    # so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
-                    # the max len, then we (for instance) double the cache size. This implementation already exists
-                    # in `transformers`. (molbap)
-                    key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
-                    value_states = torch.cat(
-                        [past_key_values[layer_idx]["value_states"], value_states], dim=1
-                    )
-
-            attention_interface = self.get_attention_interface()
-            att_output = attention_interface(
-                attention_mask, batch_size, head_dim, query_states, key_states, value_states
-            )
-            att_output = att_output.to(dtype=torch.bfloat16)
-
-            # first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
-            outputs_embeds = []
-            start = 0
-            for i, hidden_states in enumerate(inputs_embeds):
-                layer = models[i].layers[layer_idx]
-
-                if hidden_states is not None:
-                    end = start + hidden_states.shape[1]
-
-                    if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
-                        att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
-                    out_emb = layer.self_attn.o_proj(att_output[:, start:end])
-
-                    # TODO: first dropout (by default 0.0)
-
-                    # first residual
-                    out_emb += hidden_states
-                    after_first_residual = out_emb.clone()
-
-                    out_emb = layer.post_attention_layernorm(out_emb)
-                    out_emb = layer.mlp(out_emb)
-
-                    # TODO: second dropout (by default 0.0)
-
-                    # second residual
-                    out_emb += after_first_residual
-
-                    outputs_embeds.append(out_emb)
-
-                    start = end
-                else:
-                    outputs_embeds.append(None)
-
-            inputs_embeds = outputs_embeds
-
-        # final norm
-        outputs_embeds = []
-        for i, hidden_states in enumerate(inputs_embeds):
-            if hidden_states is not None:
-                out_emb = models[i].norm(hidden_states)
-                outputs_embeds.append(out_emb)
-            else:
-                outputs_embeds.append(None)
-
-        return outputs_embeds, past_key_values
-
-    def get_attention_interface(self):
-        if self.config.attention_implementation == "fa2":
-            attention_interface = self.flash_attention_forward
-        elif self.config.attention_implementation == "flex":
-            attention_interface = flex_attention_forward
-        else:
-            attention_interface = self.eager_attention_forward
-        return attention_interface
-
-    def flash_attention_forward(
-        self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
-    ):
-        raise NotImplementedError("FA2 is not implemented (yet)")
-
-    def eager_attention_forward(
-        self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
-    ):
-        num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
-        num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
-        num_key_value_groups = num_att_heads // num_key_value_heads
-
-        # query_states: batch_size, sequence_length, num_att_head, head_dim
-        # key_states: batch_size, sequence_length, num_key_value_head, head_dim
-        # value_states: batch_size, sequence_length, num_key_value_head, head_dim
-        sequence_length = key_states.shape[1]
-
-        key_states = key_states[:, :, :, None, :].expand(
-            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
-        )
-        key_states = key_states.reshape(
-            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
-        )
-
-        value_states = value_states[:, :, :, None, :].expand(
-            batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
-        )
-        value_states = value_states.reshape(
-            batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
-        )
-
-        # Attention here is upcasted to float32 to match the original eager implementation.
-
-        query_states = query_states.to(dtype=torch.float32)
-        key_states = key_states.to(dtype=torch.float32)
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-
-        att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
-        att_weights *= head_dim**-0.5
-        big_neg = -2.3819763e38  # See gemma/modules.py
-
-        masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
-
-        probs = nn.functional.softmax(masked_att_weights, dim=-1)
-        probs = probs.to(dtype=value_states.dtype)
-
-        # probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
-        # value_states: batch_size, sequence_length, num_att_heads, head_dim
-
-        att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
-
-        att_output = att_output.permute(0, 2, 1, 3)
-        # we use -1 because sequence length can change
-        att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
-
-        return att_output
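The eager path above implements grouped-query attention by materializing each key/value head once per query head in its group. A minimal sketch of that expand/reshape (shapes invented; Pi0's PaliGemma text config uses 8 query heads over 1 KV head):

```python
import torch

b, s, kv_heads, groups, d = 2, 5, 1, 8, 256
k = torch.randn(b, s, kv_heads, d)

k_exp = k[:, :, :, None, :].expand(b, s, kv_heads, groups, d)
k_exp = k_exp.reshape(b, s, kv_heads * groups, d)

# Same result as repeating each KV head for every query head in its group:
assert torch.equal(k_exp, k.repeat_interleave(groups, dim=2))
```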
diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py
deleted file mode 100644
index cd9712201..000000000
--- a/src/lerobot/policies/pi0/processor_pi0.py
+++ /dev/null
@@ -1,166 +0,0 @@
-#!/usr/bin/env python
-
-# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import Any
-
-import torch
-
-from lerobot.configs.types import PipelineFeatureType, PolicyFeature
-from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
-from lerobot.policies.pi0.configuration_pi0 import PI0Config
-from lerobot.processor import (
-    AddBatchDimensionProcessorStep,
-    ComplementaryDataProcessorStep,
-    DeviceProcessorStep,
-    NormalizerProcessorStep,
-    PolicyAction,
-    PolicyProcessorPipeline,
-    ProcessorStep,
-    ProcessorStepRegistry,
-    RenameObservationsProcessorStep,
-    TokenizerProcessorStep,
-    UnnormalizerProcessorStep,
-)
-from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
-
-
-@ProcessorStepRegistry.register(name="pi0_new_line_processor")
-class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
-    """
-    Ensures that the task description string ends with a newline character.
-
-    This processing step is required for compatibility with the PaliGemma tokenizer,
-    which expects a newline at the end of the text prompt. It handles both single
-    strings and lists of strings for the 'task' key in complementary data.
-    """
-
-    def complementary_data(self, complementary_data):
-        """
-        Adds a newline to the 'task' field if it doesn't already have one.
-
-        Args:
-            complementary_data: A dictionary that may contain a 'task' key with a
-                string or list of strings.
-
-        Returns:
-            A new dictionary with the modified 'task' field.
-        """
-        if "task" not in complementary_data:
-            return complementary_data
-
-        task = complementary_data["task"]
-        if task is None:
-            return complementary_data
-
-        new_complementary_data = dict(complementary_data)
-
-        # Handle both string and list of strings
-        if isinstance(task, str):
-            # Single string: add newline if not present
-            if not task.endswith("\n"):
-                new_complementary_data["task"] = f"{task}\n"
-        elif isinstance(task, list) and all(isinstance(t, str) for t in task):
-            # List of strings: add newline to each if not present
-            new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
-        # If task is neither string nor list of strings, leave unchanged
-
-        return new_complementary_data
-
-    def transform_features(
-        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
-    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
-        """
-        This step does not alter the feature definitions.
-
-        Args:
-            features: The input feature dictionary.
-
-        Returns:
-            The unchanged feature dictionary.
-        """
-        return features
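In effect, the step above only rewrites the `task` entry; a sketch of the expected behavior (inputs invented):

```python
step = Pi0NewLineProcessor()

assert step.complementary_data({"task": "pick up the cube"})["task"] == "pick up the cube\n"
assert step.complementary_data({"task": "stack\n"})["task"] == "stack\n"          # already terminated
assert step.complementary_data({"task": ["a", "b\n"]})["task"] == ["a\n", "b\n"]  # per item for lists
```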
-
-
-def make_pi0_pre_post_processors(
-    config: PI0Config,
-    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
-) -> tuple[
-    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
-    PolicyProcessorPipeline[PolicyAction, PolicyAction],
-]:
-    """
-    Constructs pre-processor and post-processor pipelines for the PI0 policy.
-
-    The pre-processing pipeline prepares input data for the model by:
-    1. Renaming features to match pretrained configurations.
-    2. Normalizing input and output features based on dataset statistics.
-    3. Adding a batch dimension.
-    4. Appending a newline character to the task description for tokenizer compatibility.
-    5. Tokenizing the text prompt using the PaliGemma tokenizer.
-    6. Moving all data to the specified device.
-
-    The post-processing pipeline handles the model's output by:
-    1. Moving data to the CPU.
-    2. Unnormalizing the output features to their original scale.
-
-    Args:
-        config: The configuration object for the PI0 policy.
-        dataset_stats: A dictionary of statistics for normalization.
-        preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
-        postprocessor_kwargs: Additional arguments for the post-processor pipeline.
-
-    Returns:
-        A tuple containing the configured pre-processor and post-processor pipelines.
-    """
-
-    # Add remaining processors
-    input_steps: list[ProcessorStep] = [
-        RenameObservationsProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
-        AddBatchDimensionProcessorStep(),
-        Pi0NewLineProcessor(),  # Add newlines before tokenization for PaliGemma
-        TokenizerProcessorStep(
-            tokenizer_name="google/paligemma-3b-pt-224",
-            max_length=config.tokenizer_max_length,
-            padding_side="right",
-            padding="max_length",
-        ),
-        DeviceProcessorStep(device=config.device),
-        NormalizerProcessorStep(
-            features={**config.input_features, **config.output_features},
-            norm_map=config.normalization_mapping,
-            stats=dataset_stats,
-        ),
-    ]
-
-    output_steps: list[ProcessorStep] = [
-        UnnormalizerProcessorStep(
-            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
-        ),
-        DeviceProcessorStep(device="cpu"),
-    ]
-
-    return (
-        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
-            steps=input_steps,
-            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
-        ),
-        PolicyProcessorPipeline[PolicyAction, PolicyAction](
-            steps=output_steps,
-            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
-            to_transition=policy_action_to_transition,
-            to_output=transition_to_policy_action,
-        ),
-    )
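A usage sketch of the removed factory (`raw_batch` and `policy` are hypothetical; `dataset_stats` would normally come from a LeRobotDataset):

```python
preprocessor, postprocessor = make_pi0_pre_post_processors(config, dataset_stats)

batch = preprocessor(raw_batch)       # rename → batch dim → "\n" on task → tokenize → device → normalize
action = policy.select_action(batch)
action = postprocessor(action)        # unnormalize, then move to CPU
```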
diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/configuration_gemma.py b/src/lerobot/policies/pi0/transformers_replace/models/gemma/configuration_gemma.py
similarity index 100%
rename from src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/configuration_gemma.py
rename to src/lerobot/policies/pi0/transformers_replace/models/gemma/configuration_gemma.py
diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi0/transformers_replace/models/gemma/modeling_gemma.py
similarity index 100%
rename from src/lerobot/policies/pi05_openpi/transformers_replace/models/gemma/modeling_gemma.py
rename to src/lerobot/policies/pi0/transformers_replace/models/gemma/modeling_gemma.py
diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi0/transformers_replace/models/paligemma/modeling_paligemma.py
similarity index 100%
rename from src/lerobot/policies/pi05_openpi/transformers_replace/models/paligemma/modeling_paligemma.py
rename to src/lerobot/policies/pi0/transformers_replace/models/paligemma/modeling_paligemma.py
diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi0/transformers_replace/models/siglip/check.py
similarity index 100%
rename from src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/check.py
rename to src/lerobot/policies/pi0/transformers_replace/models/siglip/check.py
diff --git a/src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi0/transformers_replace/models/siglip/modeling_siglip.py
similarity index 100%
rename from src/lerobot/policies/pi05_openpi/transformers_replace/models/siglip/modeling_siglip.py
rename to src/lerobot/policies/pi0/transformers_replace/models/siglip/modeling_siglip.py
diff --git a/src/lerobot/policies/pi05_openpi/README.md b/src/lerobot/policies/pi05/README.md
similarity index 100%
rename from src/lerobot/policies/pi05_openpi/README.md
rename to src/lerobot/policies/pi05/README.md
diff --git a/src/lerobot/policies/pi05_openpi/__init__.py b/src/lerobot/policies/pi05/__init__.py
similarity index 100%
rename from src/lerobot/policies/pi05_openpi/__init__.py
rename to src/lerobot/policies/pi05/__init__.py
diff --git a/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py b/src/lerobot/policies/pi05/configuration_pi05openpi.py
similarity index 99%
rename from src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py
rename to src/lerobot/policies/pi05/configuration_pi05openpi.py
index 3b8b4779e..3dc4436cb 100644
--- a/src/lerobot/policies/pi05_openpi/configuration_pi05openpi.py
+++ b/src/lerobot/policies/pi05/configuration_pi05openpi.py
@@ -22,7 +22,7 @@
 from lerobot.optim.optimizers import AdamWConfig
 from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
 
-@PreTrainedConfig.register_subclass("pi05_openpi")
+@PreTrainedConfig.register_subclass("pi05")
 @dataclass
 class PI05OpenPIConfig(PreTrainedConfig):
     # Model architecture
diff --git a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py b/src/lerobot/policies/pi05/modeling_pi05openpi.py
similarity index 99%
rename from src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py
rename to src/lerobot/policies/pi05/modeling_pi05openpi.py
index 6b6d328e1..eb6f95934 100644
--- a/src/lerobot/policies/pi05_openpi/modeling_pi05openpi.py
+++ b/src/lerobot/policies/pi05/modeling_pi05openpi.py
@@ -34,7 +34,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.constants import ACTION, OBS_STATE
 from lerobot.policies.normalize import Normalize, Unnormalize
-from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
+from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
 from lerobot.policies.pretrained import PreTrainedPolicy, T
 
 
@@ -525,7 +525,7 @@ class PI05Pytorch(nn.Module):  # see openpi `PI0Pytorch`
         msg = """transformers_replace is not installed correctly.
 Please install it with `pip install transformers==4.53.2`
-and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
+and `cp -r ./src/lerobot/policies/pi0/transformers_replace/* \
 $(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`"""
 
         try:
@@ -820,7 +820,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
     """PI05 OpenPI Policy for LeRobot."""
 
     config_class = PI05OpenPIConfig
-    name = "pi05_openpi"
+    name = "pi05"
 
     def __init__(  # see lerobot pi0 `__init__`
         self,
diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/configuration_gemma.py b/src/lerobot/policies/pi05/transformers_replace/models/gemma/configuration_gemma.py
similarity index 100%
rename from src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/configuration_gemma.py
rename to src/lerobot/policies/pi05/transformers_replace/models/gemma/configuration_gemma.py
diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py b/src/lerobot/policies/pi05/transformers_replace/models/gemma/modeling_gemma.py
similarity index 100%
rename from src/lerobot/policies/pi0_openpi/transformers_replace/models/gemma/modeling_gemma.py
rename to src/lerobot/policies/pi05/transformers_replace/models/gemma/modeling_gemma.py
diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py b/src/lerobot/policies/pi05/transformers_replace/models/paligemma/modeling_paligemma.py
similarity index 100%
rename from src/lerobot/policies/pi0_openpi/transformers_replace/models/paligemma/modeling_paligemma.py
rename to src/lerobot/policies/pi05/transformers_replace/models/paligemma/modeling_paligemma.py
diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py b/src/lerobot/policies/pi05/transformers_replace/models/siglip/check.py
similarity index 100%
rename from src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/check.py
rename to src/lerobot/policies/pi05/transformers_replace/models/siglip/check.py
diff --git a/src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py b/src/lerobot/policies/pi05/transformers_replace/models/siglip/modeling_siglip.py
similarity index 100%
rename from src/lerobot/policies/pi0_openpi/transformers_replace/models/siglip/modeling_siglip.py
rename to src/lerobot/policies/pi05/transformers_replace/models/siglip/modeling_siglip.py
diff --git a/src/lerobot/scripts/server/constants.py b/src/lerobot/scripts/server/constants.py
index 69dbffbcc..5ebf3780c 100644
--- a/src/lerobot/scripts/server/constants.py
+++ b/src/lerobot/scripts/server/constants.py
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
 DEFAULT_OBS_QUEUE_TIMEOUT = 2
 
 # All action chunking policies
-SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet", "pi0_openpi", "pi05_openpi"]
+SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
 
 # TODO: Add all other robots
 SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
diff --git a/src/lerobot/scripts/server/helpers.py b/src/lerobot/scripts/server/helpers.py
index d8051b76e..41edc5850 100644
--- a/src/lerobot/scripts/server/helpers.py
+++ b/src/lerobot/scripts/server/helpers.py
@@ -26,7 +26,6 @@ from lerobot.constants import OBS_IMAGES, OBS_STATE
 from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
 
 # NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
-from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig  # noqa: F401
 from lerobot.robots.robot import Robot
 from lerobot.utils.utils import init_logging
diff --git a/src/lerobot/templates/lerobot_modelcard_template.md b/src/lerobot/templates/lerobot_modelcard_template.md
index 9edd4bce7..34af282b0 100644
--- a/src/lerobot/templates/lerobot_modelcard_template.md
+++ b/src/lerobot/templates/lerobot_modelcard_template.md
@@ -19,11 +19,9 @@
 [Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
 {% elif model_name == "vqbet" %}
 [VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
-{% elif model_name == "pi0" %}
-[Pi0](https://huggingface.co/papers/2410.24164) is a generalist vision-language-action transformer that converts multimodal observations and text instructions into robot actions for zero-shot task transfer.
 {% elif model_name == "pi0fast" %}
 [Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time.
-{% elif model_name == "pi0_openpi" %}
+{% elif model_name == "pi0" %}
 **π₀ (Pi0)**
 
 π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
@@ -33,7 +31,7 @@
 π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
 
 For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
-{% elif model_name == "pi05_openpi" %}
+{% elif model_name == "pi05" %}
 **π₀.₅ (Pi05) Policy**
 
 π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
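Net effect of the registration and template hunks above: the user-facing policy type strings drop their `_openpi` suffix. A sketch using the factory helper already exercised by the tests below:

```python
from lerobot.policies.factory import make_policy_config

# "pi0" and "pi05" now resolve to the OpenPI-ported implementations;
# the old "pi0_openpi" / "pi05_openpi" names are gone.
config = make_policy_config(policy_type="pi05", max_action_dim=7, max_state_dim=14)
```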
diff --git a/tests/policies/pi0_pi05/test_pi05_openpi.py b/tests/policies/pi0_pi05/test_pi05_openpi.py
index 135f3265c..26e9bd948 100644
--- a/tests/policies/pi0_pi05/test_pi05_openpi.py
+++ b/tests/policies/pi0_pi05/test_pi05_openpi.py
@@ -13,7 +13,7 @@ pytestmark = pytest.mark.skipif(
     reason="This test requires local OpenPI installation and is not meant for CI",
 )
 
-from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy  # noqa: E402
+from lerobot.policies.pi05 import PI05Config, PI05Policy  # noqa: E402
 from tests.utils import require_cuda  # noqa: E402
 
 
@@ -22,7 +22,7 @@ def test_pi05_model_architecture():
     """Test that pi05=True creates the correct model architecture."""
 
     # Create config
-    config = PI05OpenPIConfig(
+    config = PI05Config(
         max_action_dim=7,
         max_state_dim=14,
         dtype="float32",
@@ -73,7 +73,7 @@ def test_pi05_model_architecture():
     }
 
     # Instantiate policy
-    policy = PI05OpenPIPolicy(config, dataset_stats)
+    policy = PI05Policy(config, dataset_stats)
 
     # Verify pi05 model components exist
     # Check that time_mlp layers exist (for AdaRMS conditioning)
@@ -104,7 +104,7 @@ def test_pi05_forward_pass():
     """Test forward pass with"""
 
     # Create config
-    config = PI05OpenPIConfig(
+    config = PI05Config(
         max_action_dim=7,
         max_state_dim=14,
         dtype="float32",
@@ -150,7 +150,7 @@ def test_pi05_forward_pass():
     }
 
     # Instantiate policy
-    policy = PI05OpenPIPolicy(config, dataset_stats)
+    policy = PI05Policy(config, dataset_stats)
 
     # Create test batch
     batch_size = 2
diff --git a/tests/policies/pi0_pi05/test_pi0_openpi.py b/tests/policies/pi0_pi05/test_pi0_openpi.py
index f48efd5e0..472f143de 100644
--- a/tests/policies/pi0_pi05/test_pi0_openpi.py
+++ b/tests/policies/pi0_pi05/test_pi0_openpi.py
@@ -14,7 +14,7 @@ pytestmark = pytest.mark.skipif(
 )
 
 from lerobot.policies.factory import make_policy_config  # noqa: E402
-from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy  # noqa: E402
+from lerobot.policies.pi0 import PI0OpenPIConfig, PI0OpenPIPolicy  # noqa: E402
 from tests.utils import require_cuda  # noqa: E402
 
 
@@ -96,7 +96,7 @@ def test_config_creation():
     """Test policy config creation through factory."""
     try:
         config = make_policy_config(
-            policy_type="pi0_openpi",
+            policy_type="pi0",
             max_action_dim=7,
             max_state_dim=14,
         )
diff --git a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py
index 25e406372..47a2ddeab 100644
--- a/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py
+++ b/tests/policies/pi0_pi05/test_pi0_original_vs_lerobot.py
@@ -21,7 +21,7 @@ from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
 from openpi.models_pytorch.pi0_pytorch import PI0Pytorch  # noqa: E402
 from transformers import AutoTokenizer  # noqa: E402
 
-from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy  # noqa: E402
+from lerobot.policies.pi0 import PI0Config, PI0Policy  # noqa: E402
 
 DUMMY_ACTION_DIM = 32
 DUMMY_STATE_DIM = 32
@@ -68,9 +68,7 @@ class PI0BaseOriginalConfig:
 def instantiate_lerobot_pi0(from_pretrained: bool = False):
     if from_pretrained:
         # Load the policy first
-        policy = PI0OpenPIPolicy.from_pretrained(
-            pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
-        )
+        policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True)
 
         # Then reinitialize the normalization with proper stats
         from lerobot.policies.normalize import Normalize, Unnormalize
@@ -84,10 +82,8 @@ def instantiate_lerobot_pi0(from_pretrained: bool = False):
             policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
         )
     else:
-        config = PI0OpenPIConfig(
-            max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
-        )
-        policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
+        config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
+        policy = PI0Policy(config, DUMMY_DATASET_STATS)
 
     policy.to(DEVICE)
     return policy
diff --git a/tests/policies/pi0_pi05/test_pi0_pi05_hub.py b/tests/policies/pi0_pi05/test_pi0_pi05_hub.py
index 25671d8ec..92e918422 100644
--- a/tests/policies/pi0_pi05/test_pi0_pi05_hub.py
+++ b/tests/policies/pi0_pi05/test_pi0_pi05_hub.py
@@ -18,8 +18,8 @@ pytestmark = pytest.mark.skipif(
     reason="This test requires HuggingFace authentication and is not meant for CI",
 )
 
-from lerobot.policies.pi0_openpi import PI0OpenPIPolicy  # noqa: E402
-from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy  # noqa: E402
+from lerobot.policies.pi0 import PI0Policy  # noqa: E402
+from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy  # noqa: E402
 
 
 def create_dummy_stats(config):
@@ -48,13 +48,13 @@ def create_dummy_stats(config):
 # Test data for all 6 base models
 MODEL_TEST_PARAMS = [
     # PI0 models
-    ("pepijn223/pi0_base_fp32", "PI0", PI0OpenPIPolicy),
-    ("pepijn223/pi0_droid_fp32", "PI0", PI0OpenPIPolicy),
-    ("pepijn223/pi0_libero_fp32", "PI0", PI0OpenPIPolicy),
+    ("pepijn223/pi0_base_fp32", "PI0", PI0Policy),
+    ("pepijn223/pi0_droid_fp32", "PI0", PI0Policy),
+    ("pepijn223/pi0_libero_fp32", "PI0", PI0Policy),
     # PI0.5 models
-    ("pepijn223/pi05_base_fp32", "PI0.5", PI05OpenPIPolicy),
-    ("pepijn223/pi05_droid_fp32", "PI0.5", PI05OpenPIPolicy),
-    ("pepijn223/pi05_libero_fp32", "PI0.5", PI05OpenPIPolicy),
+    ("pepijn223/pi05_base_fp32", "PI0.5", PI05Policy),
+    ("pepijn223/pi05_droid_fp32", "PI0.5", PI05Policy),
+    ("pepijn223/pi05_libero_fp32", "PI0.5", PI05Policy),
 ]
 
 
@@ -65,7 +65,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
     Args:
         model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base_fp32")
        model_type: Model type ("PI0" or "PI0.5")
-        policy_class: Policy class to use (PI0OpenPIPolicy or PI05OpenPIPolicy)
+        policy_class: Policy class to use (PI0Policy or PI05Policy)
     """
     print(f"\n{'=' * 80}")
     print(f"Testing {model_type} model: {model_id}")
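For reference, the hub checkpoints listed above load through the renamed classes; a minimal sketch (downloads the fp32 weights):

```python
from lerobot.policies.pi0 import PI0Policy
from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy

pi0 = PI0Policy.from_pretrained("pepijn223/pi0_base_fp32")
pi05 = PI05Policy.from_pretrained("pepijn223/pi05_base_fp32")
```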
-"""Tests for PI0 policy processor.""" - -from unittest.mock import patch - -import pytest -import torch - -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature -from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE -from lerobot.policies.pi0.configuration_pi0 import PI0Config -from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors -from lerobot.processor import ( - AddBatchDimensionProcessorStep, - DeviceProcessorStep, - EnvTransition, - NormalizerProcessorStep, - ProcessorStep, - RenameObservationsProcessorStep, - TransitionKey, - UnnormalizerProcessorStep, -) -from lerobot.processor.converters import create_transition, transition_to_batch - - -class MockTokenizerProcessorStep(ProcessorStep): - """Mock tokenizer processor step for testing.""" - - def __init__(self, *args, **kwargs): - # Accept any arguments to mimic the real TokenizerProcessorStep interface - pass - - def __call__(self, transition: EnvTransition) -> EnvTransition: - # Pass through transition unchanged - return transition - - def transform_features(self, features): - # Pass through features unchanged - return features - - -def create_default_config(): - """Create a default PI0 configuration for testing.""" - config = PI0Config() - config.input_features = { - OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)), - OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)), - } - config.output_features = { - ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)), - } - config.normalization_mapping = { - FeatureType.STATE: NormalizationMode.MEAN_STD, - FeatureType.VISUAL: NormalizationMode.IDENTITY, - FeatureType.ACTION: NormalizationMode.MIN_MAX, - } - config.device = "cpu" - config.tokenizer_max_length = 128 - return config - - -def create_default_stats(): - """Create default dataset statistics for testing.""" - return { - OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)}, - OBS_IMAGE: {}, # No normalization for images - ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)}, - } - - -def test_make_pi0_processor_basic(): - """Test basic creation of PI0 processor.""" - config = create_default_config() - stats = create_default_stats() - - with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep): - preprocessor, postprocessor = make_pi0_pre_post_processors( - config, - stats, - ) - - # Check processor names - assert preprocessor.name == "policy_preprocessor" - assert postprocessor.name == "policy_postprocessor" - - # Check steps in preprocessor - assert len(preprocessor.steps) == 6 - assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep) - assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep) - assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor) - # Step 3 would be TokenizerProcessorStep but it's mocked - assert isinstance(preprocessor.steps[4], DeviceProcessorStep) - assert isinstance(preprocessor.steps[5], NormalizerProcessorStep) - - # Check steps in postprocessor - assert len(postprocessor.steps) == 2 - assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep) - assert isinstance(postprocessor.steps[1], DeviceProcessorStep) - - -def test_pi0_newline_processor_single_task(): - """Test Pi0NewLineProcessor with single task string.""" - processor = Pi0NewLineProcessor() - - # Test with task that doesn't have newline - transition = create_transition(complementary_data={"task": "test task"}) - result = 
processor(transition)
-    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
-
-    # Test with task that already has newline
-    transition = create_transition(complementary_data={"task": "test task\n"})
-    result = processor(transition)
-    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
-
-
-def test_pi0_newline_processor_list_of_tasks():
-    """Test Pi0NewLineProcessor with list of task strings."""
-    processor = Pi0NewLineProcessor()
-
-    # Test with list of tasks
-    tasks = ["task1", "task2\n", "task3"]
-    transition = create_transition(complementary_data={"task": tasks})
-    result = processor(transition)
-    expected = ["task1\n", "task2\n", "task3\n"]
-    assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected
-
-
-def test_pi0_newline_processor_empty_transition():
-    """Test Pi0NewLineProcessor with empty transition."""
-    processor = Pi0NewLineProcessor()
-
-    # Test with no complementary_data
-    transition = create_transition()
-    result = processor(transition)
-    assert result == transition
-
-    # Test with complementary_data but no task
-    transition = create_transition(complementary_data={"other": "data"})
-    result = processor(transition)
-    assert result == transition
-
-    # Test with None task
-    transition = create_transition(complementary_data={"task": None})
-    result = processor(transition)
-    assert result == transition
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_pi0_processor_cuda():
-    """Test PI0 processor with CUDA device."""
-    config = create_default_config()
-    config.device = "cuda"
-    stats = create_default_stats()
-
-    # Mock the tokenizer processor to act as pass-through
-    class MockTokenizerProcessorStep(ProcessorStep):
-        def __init__(self, *args, **kwargs):
-            pass
-
-        def __call__(self, transition):
-            return transition
-
-        def state_dict(self):
-            return {}
-
-        def load_state_dict(self, state):
-            pass
-
-        def reset(self):
-            pass
-
-        def get_config(self):
-            return {"tokenizer_name": "google/paligemma-3b-pt-224"}
-
-        def transform_features(self, features):
-            return features
-
-    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(
-            config,
-            stats,
-        )
-
-    # Create CPU data
-    observation = {
-        OBS_STATE: torch.randn(10),
-        OBS_IMAGE: torch.randn(3, 224, 224),
-    }
-    action = torch.randn(6)
-    transition = create_transition(observation, action, complementary_data={"task": "test task"})
-    batch = transition_to_batch(transition)
-
-    # Process through preprocessor
-    processed = preprocessor(batch)
-
-    # Check that data is on CUDA
-    assert processed[OBS_STATE].device.type == "cuda"
-    assert processed[OBS_IMAGE].device.type == "cuda"
-    assert processed[TransitionKey.ACTION.value].device.type == "cuda"
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_pi0_processor_accelerate_scenario():
-    """Test PI0 processor in simulated Accelerate scenario."""
-    config = create_default_config()
-    config.device = "cuda:0"
-    stats = create_default_stats()
-
-    # Mock the tokenizer processor to act as pass-through
-    class MockTokenizerProcessorStep(ProcessorStep):
-        def __init__(self, *args, **kwargs):
-            pass
-
-        def __call__(self, transition):
-            return transition
-
-        def state_dict(self):
-            return {}
-
-        def load_state_dict(self, state):
-            pass
-
-        def reset(self):
-            pass
-
-        def get_config(self):
-            return {"tokenizer_name": "google/paligemma-3b-pt-224"}
-
-        def transform_features(self, features):
-            return features
-
-    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(
-            config,
-            stats,
-        )
-
-    # Simulate Accelerate: data already on GPU and batched
-    device = torch.device("cuda:0")
-    observation = {
-        OBS_STATE: torch.randn(1, 10).to(device),
-        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
-    }
-    action = torch.randn(1, 6).to(device)
-    transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
-    batch = transition_to_batch(transition)
-
-    # Process through preprocessor
-    processed = preprocessor(batch)
-
-    # Check that data stays on same GPU
-    assert processed[OBS_STATE].device == device
-    assert processed[OBS_IMAGE].device == device
-    assert processed[TransitionKey.ACTION.value].device == device
-
-
-@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
-def test_pi0_processor_multi_gpu():
-    """Test PI0 processor with multi-GPU setup."""
-    config = create_default_config()
-    config.device = "cuda:0"
-    stats = create_default_stats()
-
-    # Mock the tokenizer processor to act as pass-through
-    class MockTokenizerProcessorStep(ProcessorStep):
-        def __init__(self, *args, **kwargs):
-            pass
-
-        def __call__(self, transition):
-            return transition
-
-        def state_dict(self):
-            return {}
-
-        def load_state_dict(self, state):
-            pass
-
-        def reset(self):
-            pass
-
-        def get_config(self):
-            return {"tokenizer_name": "google/paligemma-3b-pt-224"}
-
-        def transform_features(self, features):
-            return features
-
-    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(
-            config,
-            stats,
-        )
-
-    # Simulate data on different GPU
-    device = torch.device("cuda:1")
-    observation = {
-        OBS_STATE: torch.randn(1, 10).to(device),
-        OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
-    }
-    action = torch.randn(1, 6).to(device)
-    transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
-    batch = transition_to_batch(transition)
-
-    # Process through preprocessor
-    processed = preprocessor(batch)
-
-    # Check that data stays on cuda:1
-    assert processed[OBS_STATE].device == device
-    assert processed[OBS_IMAGE].device == device
-    assert processed[TransitionKey.ACTION.value].device == device
-
-
-def test_pi0_processor_without_stats():
-    """Test PI0 processor creation without dataset statistics."""
-    config = create_default_config()
-
-    # Mock the tokenizer processor
-    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
-        preprocessor, postprocessor = make_pi0_pre_post_processors(
-            config,
-            dataset_stats=None,
-        )
-
-    # Should still create processors
-    assert preprocessor is not None
-    assert postprocessor is not None
-
-
-def test_pi0_newline_processor_state_dict():
-    """Test Pi0NewLineProcessor state dict methods."""
-    processor = Pi0NewLineProcessor()
-
-    # Test state_dict (should be empty)
-    state = processor.state_dict()
-    assert state == {}
-
-    # Test load_state_dict (should do nothing)
-    processor.load_state_dict({})
-
-    # Test reset (should do nothing)
-    processor.reset()
-
-    # Test get_config
-    config = processor.get_config()
-    assert config == {}
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-def test_pi0_processor_bfloat16_device_float32_normalizer():
-    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
-    config = create_default_config()
-    stats = create_default_stats()
-    config.device = "cuda"
-
-    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
-        preprocessor, _ = make_pi0_pre_post_processors(
-            config,
-            stats,
-        )
-
-    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
-    modified_steps = []
-    for step in preprocessor.steps:
-        if isinstance(step, DeviceProcessorStep):
-            # Device processor converts to bfloat16
-            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
-        elif isinstance(step, NormalizerProcessorStep):
-            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
-            norm_step = step  # Now type checker knows this is NormalizerProcessorStep
-            modified_steps.append(
-                NormalizerProcessorStep(
-                    features=norm_step.features,
-                    norm_map=norm_step.norm_map,
-                    stats=norm_step.stats,
-                    device=config.device,
-                    dtype=torch.float32,  # Deliberately configured as float32
-                )
-            )
-        else:
-            modified_steps.append(step)
-    preprocessor.steps = modified_steps
-
-    # Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
-    normalizer_step = preprocessor.steps[5]  # NormalizerProcessorStep
-    assert normalizer_step.dtype == torch.float32
-
-    # Create test data with both state and visual observations
-    observation = {
-        OBS_STATE: torch.randn(10, dtype=torch.float32),  # PI0 expects size 10
-        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
-    }
-    action = torch.randn(6, dtype=torch.float32)  # PI0 expects size 6
-    transition = create_transition(
-        observation, action, complementary_data={"task": "test bfloat16 adaptation"}
-    )
-    batch = transition_to_batch(transition)
-
-    # Process through full pipeline
-    processed = preprocessor(batch)
-
-    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
-    assert processed[OBS_STATE].dtype == torch.bfloat16
-    assert processed[OBS_IMAGE].dtype == torch.bfloat16  # IDENTITY normalization still gets dtype conversion
-    assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16
-
-    # Verify normalizer automatically adapted its internal state
-    assert normalizer_step.dtype == torch.bfloat16
-    # Check state stats (has normalization)
-    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
-        assert stat_tensor.dtype == torch.bfloat16
-    # OBS_IMAGE uses IDENTITY normalization, so no stats to check