Remove previous pi0 and rename pi0_openpi and pi05_openpi

This commit is contained in:
Pepijn
2025-09-22 17:11:29 +02:00
parent 83ed49d9b9
commit 5d9acf9d51
41 changed files with 54 additions and 2813 deletions
+4
View File
@@ -124,3 +124,7 @@ python src/lerobot/scripts/train.py \
LeRobot uses MuJoCo for simulation. You need to set the rendering backend before training or evaluation:
- `export MUJOCO_GL=egl` → for headless servers (e.g. HPC, cloud)
## Reproducing π₀ and π₀.₅ results
We can also reproduce the results of π₀ and π₀.₅ on the Libero benchmark by using the finetuned libero models.
+3 -3
View File
@@ -35,7 +35,7 @@ As described by Physical Intelligence, while AI has achieved remarkable success
2. Apply the custom patches:
```bash
cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
cp -r ./src/lerobot/policies/pi0/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
```
@@ -72,7 +72,7 @@ pip install transformers==4.53.2
To use π₀ in LeRobot, specify the policy type as:
```python
policy.type=pi0_openpi
policy.type=pi0
```
## Training
@@ -82,7 +82,7 @@ For training π₀, you can use the standard LeRobot training script with the ap
```bash
python src/lerobot/scripts/train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi0_openpi \
--policy.type=pi0 \
--output_dir=./outputs/pi0_training \
--job_name=pi0_training \
--policy.pretrained_path=pepijn223/pi0_base_fp32 \
+3 -3
View File
@@ -43,7 +43,7 @@ This diverse training mixture creates a "curriculum" that enables generalization
2. Apply the custom patches:
```bash
cp -r ./src/lerobot/policies/pi05_openpi/transformers_replace/* \
cp -r ./src/lerobot/policies/pi05/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")
```
@@ -72,7 +72,7 @@ pip install transformers==4.53.2
To use π₀.₅ in your LeRobot configuration, specify the policy type as:
```python
policy.type=pi05_openpi
policy.type=pi05
```
## Training
@@ -84,7 +84,7 @@ Here's a complete training command for finetuning the base π₀.₅ model on yo
```bash
python src/lerobot/scripts/train.py \
--dataset.repo_id=your_dataset \
--policy.type=pi05_openpi \
--policy.type=pi05 \
--output_dir=./outputs/pi0_training \
--job_name=pi0_training \
--policy.repo_id=pepijn223/pi05_base_fp32 \
+1 -2
View File
@@ -147,8 +147,7 @@ all = [
"lerobot[reachy2]",
"lerobot[kinematics]",
"lerobot[intelrealsense]",
"lerobot[pi0]",
"lerobot[pi05]",
"lerobot[pi]",
"lerobot[smolvla]",
"lerobot[hilserl]",
"lerobot[async]",
+4 -5
View File
@@ -14,10 +14,8 @@
from .act.configuration_act import ACTConfig as ACTConfig
from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig
from .pi0.configuration_pi0 import PI0Config as PI0Config
from .pi0.processor_pi0 import Pi0NewLineProcessor
from .pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig
from .pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig
from .pi0.configuration_pi0openpi import PI0OpenPIConfig as PI0OpenPIConfig
from .pi05.configuration_pi05openpi import PI05OpenPIConfig as PI05OpenPIConfig
from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
@@ -26,7 +24,8 @@ from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
__all__ = [
"ACTConfig",
"DiffusionConfig",
"PI0Config",
"PI0OpenPIConfig",
"PI05OpenPIConfig",
"SmolVLAConfig",
"TDMPCConfig",
"VQBeTConfig",
+8 -23
View File
@@ -31,10 +31,9 @@ from lerobot.envs.configs import EnvConfig
from lerobot.envs.utils import env_to_policy_features
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pi0fast.configuration_pi0fast import PI0FASTConfig
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
@@ -83,20 +82,16 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
return VQBeTPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
return PI0Policy
elif name == "pi0fast":
from lerobot.policies.pi0fast.modeling_pi0fast import PI0FASTPolicy
return PI0FASTPolicy
elif name == "pi0_openpi":
from lerobot.policies.pi0_openpi.modeling_pi0openpi import PI0OpenPIPolicy
elif name == "pi0":
from lerobot.policies.pi0.modeling_pi0openpi import PI0OpenPIPolicy
return PI0OpenPIPolicy
elif name == "pi05_openpi":
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy
elif name == "pi05":
from lerobot.policies.pi05.modeling_pi05openpi import PI05OpenPIPolicy
return PI05OpenPIPolicy
elif name == "sac":
@@ -142,13 +137,11 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return ACTConfig(**kwargs)
elif policy_type == "vqbet":
return VQBeTConfig(**kwargs)
elif policy_type == "pi0":
return PI0Config(**kwargs)
elif policy_type == "pi0fast":
return PI0FASTConfig(**kwargs)
elif policy_type == "pi0_openpi":
elif policy_type == "pi0":
return PI0OpenPIConfig(**kwargs)
elif policy_type == "pi05_openpi":
elif policy_type == "pi05":
return PI05OpenPIConfig(**kwargs)
elif policy_type == "sac":
return SACConfig(**kwargs)
@@ -267,14 +260,6 @@ def make_pre_post_processors(
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0Config):
from lerobot.policies.pi0.processor_pi0 import make_pi0_pre_post_processors
processors = make_pi0_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, PI0FASTConfig):
from lerobot.policies.pi0fast.processor_pi0fast import make_pi0fast_pre_post_processors
@@ -1,149 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import (
CosineDecayWithWarmupSchedulerConfig,
)
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0Config(PreTrainedConfig):
# Input / output structure.
n_obs_steps: int = 1
chunk_size: int = 50
n_action_steps: int = 50
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.IDENTITY,
"STATE": NormalizationMode.MEAN_STD,
"ACTION": NormalizationMode.MEAN_STD,
}
)
# Shorter state and action vectors will be padded
max_state_dim: int = 32
max_action_dim: int = 32
# Image preprocessing
resize_imgs_with_padding: tuple[int, int] = (224, 224)
# Add empty images. Used by pi0_aloha_sim which adds the empty
# left and right wrist cameras in addition to the top camera.
empty_cameras: int = 0
# Converts the joint and gripper values from the standard Aloha space to
# the space used by the pi internal runtime which was used to train the base model.
adapt_to_pi_aloha: bool = False
# Converts joint dimensions to deltas with respect to the current state before passing to the model.
# Gripper dimensions will remain in absolute values.
use_delta_joint_actions_aloha: bool = False
# Tokenizer
tokenizer_max_length: int = 48
# Projector
proj_width: int = 1024
# Decoding
num_steps: int = 10
# Attention utils
use_cache: bool = True
attention_implementation: str = "eager" # or fa2, flex
# Finetuning settings
freeze_vision_encoder: bool = True
train_expert_only: bool = False
train_state_proj: bool = True
# Training presets
optimizer_lr: float = 2.5e-5
optimizer_betas: tuple[float, float] = (0.9, 0.95)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000
scheduler_decay_lr: float = 2.5e-6
# TODO: Add EMA
def __post_init__(self):
super().__post_init__()
# TODO(Steven): Validate device and amp? in all policy configs?
"""Input validation (not exhaustive)."""
if self.n_action_steps > self.chunk_size:
raise ValueError(
f"The chunk size is the upper bound for the number of action steps per model invocation. Got "
f"{self.n_action_steps} for `n_action_steps` and {self.chunk_size} for `chunk_size`."
)
if self.n_obs_steps != 1:
raise ValueError(
f"Multiple observation steps not handled yet. Got `nobs_steps={self.n_obs_steps}`"
)
if self.use_delta_joint_actions_aloha:
raise NotImplementedError(
"`use_delta_joint_actions_aloha` is used by pi0 for aloha real models. It is not ported yet in LeRobot."
)
def validate_features(self) -> None:
# TODO: implement value error
# if not self.image_features and not self.env_state_feature:
# raise ValueError("You must provide at least one image or the environment state among the inputs.")
for i in range(self.empty_cameras):
key = f"observation.images.empty_camera_{i}"
empty_camera = PolicyFeature(
type=FeatureType.VISUAL,
shape=(3, 480, 640),
)
self.input_features[key] = empty_camera
def get_optimizer_preset(self) -> AdamWConfig:
return AdamWConfig(
lr=self.optimizer_lr,
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
)
def get_scheduler_preset(self):
return CosineDecayWithWarmupSchedulerConfig(
peak_lr=self.optimizer_lr,
decay_lr=self.scheduler_decay_lr,
num_warmup_steps=self.scheduler_warmup_steps,
num_decay_steps=self.scheduler_decay_steps,
)
@property
def observation_delta_indices(self) -> None:
return None
@property
def action_delta_indices(self) -> list:
return list(range(self.chunk_size))
@property
def reward_delta_indices(self) -> None:
return None
@@ -22,7 +22,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("pi0_openpi")
@PreTrainedConfig.register_subclass("pi0")
@dataclass
class PI0OpenPIConfig(PreTrainedConfig):
# Model architecture
@@ -1,82 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
torch.backends.cudnn.benchmark = True
def main():
device = "cuda"
dataset_repo_id = "danaaubakirova/koch_test"
# model_name = "pi0_base"
# ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_torch_dir = "lerobot/pi0"
dataset = LeRobotDataset(dataset_repo_id, episodes=[0])
dataloader = torch.utils.data.DataLoader(
dataset,
num_workers=0,
batch_size=1,
)
batch = next(iter(dataloader))
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, ds_meta=dataset.meta)
# policy = torch.compile(policy, mode="reduce-overhead")
warmup_iters = 10
benchmark_iters = 30
# Warmup
for _ in range(warmup_iters):
torch.cuda.synchronize()
policy.select_action(batch)
policy.reset()
torch.cuda.synchronize()
# Benchmark
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
for _ in range(benchmark_iters):
policy.select_action(batch)
policy.reset()
end_event.record()
# Synchronize and measure time
torch.cuda.synchronize()
elapsed_time_ms = start_event.elapsed_time(end_event)
avg_time_per_iter = elapsed_time_ms / benchmark_iters
print(f"Average execution time per iteration: {avg_time_per_iter:.3f} ms")
if __name__ == "__main__":
with torch.inference_mode():
main()
@@ -1,131 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import pickle
from pathlib import Path
import torch
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
from lerobot.policies.factory import make_policy
def display(tensor: torch.Tensor):
if tensor.dtype == torch.bool:
tensor = tensor.float()
print(f"Shape: {tensor.shape}")
print(f"Mean: {tensor.mean().item()}")
print(f"Std: {tensor.std().item()}")
print(f"Min: {tensor.min().item()}")
print(f"Max: {tensor.max().item()}")
def main():
num_motors = 14
device = "cuda"
# model_name = "pi0_aloha_towel"
model_name = "pi0_aloha_sim"
if model_name == "pi0_aloha_towel":
dataset_repo_id = "lerobot/aloha_static_towel"
else:
dataset_repo_id = "lerobot/aloha_sim_transfer_cube_human"
ckpt_torch_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}_pytorch"
ckpt_jax_dir = Path.home() / f".cache/openpi/openpi-assets/checkpoints/{model_name}"
save_dir = Path(f"../openpi/data/{model_name}/save")
with open(save_dir / "example.pkl", "rb") as f:
example = pickle.load(f)
with open(save_dir / "outputs.pkl", "rb") as f:
outputs = pickle.load(f)
with open(save_dir / "noise.pkl", "rb") as f:
noise = pickle.load(f)
with open(ckpt_jax_dir / "assets/norm_stats.json") as f:
norm_stats = json.load(f)
# Override stats
dataset_meta = LeRobotDatasetMetadata(dataset_repo_id)
dataset_meta.stats["observation.state"]["mean"] = torch.tensor(
norm_stats["norm_stats"]["state"]["mean"][:num_motors], dtype=torch.float32
)
dataset_meta.stats["observation.state"]["std"] = torch.tensor(
norm_stats["norm_stats"]["state"]["std"][:num_motors], dtype=torch.float32
)
# Create LeRobot batch from Jax
batch = {}
for cam_key, uint_chw_array in example["images"].items():
batch[f"observation.images.{cam_key}"] = torch.from_numpy(uint_chw_array) / 255.0
batch["observation.state"] = torch.from_numpy(example["state"])
batch["action"] = torch.from_numpy(outputs["actions"])
batch["task"] = example["prompt"]
if model_name == "pi0_aloha_towel":
del batch["observation.images.cam_low"]
elif model_name == "pi0_aloha_sim":
batch["observation.images.top"] = batch["observation.images.cam_high"]
del batch["observation.images.cam_high"]
# Batchify
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].unsqueeze(0)
elif isinstance(batch[key], str):
batch[key] = [batch[key]]
else:
raise ValueError(f"{key}, {batch[key]}")
# To device
for k in batch:
if isinstance(batch[k], torch.Tensor):
batch[k] = batch[k].to(device=device, dtype=torch.float32)
noise = torch.from_numpy(noise).to(device=device, dtype=torch.float32)
from lerobot import policies # noqa
cfg = PreTrainedConfig.from_pretrained(ckpt_torch_dir)
cfg.pretrained_path = ckpt_torch_dir
policy = make_policy(cfg, dataset_meta)
# loss_dict = policy.forward(batch, noise=noise, time=time_beta)
# loss_dict["loss"].backward()
# print("losses")
# display(loss_dict["losses_after_forward"])
# print("pi_losses")
# display(pi_losses)
actions = []
for _ in range(50):
action = policy.select_action(batch, noise=noise)
actions.append(action)
actions = torch.stack(actions, dim=1)
pi_actions = batch["action"]
print("actions")
display(actions)
print()
print("pi_actions")
display(pi_actions)
print("atol=3e-2", torch.allclose(actions, pi_actions, atol=3e-2))
print("atol=2e-2", torch.allclose(actions, pi_actions, atol=2e-2))
print("atol=1e-2", torch.allclose(actions, pi_actions, atol=1e-2))
if __name__ == "__main__":
main()
@@ -1,84 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from transformers import GemmaConfig, PaliGemmaConfig
def get_paligemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
# image_sizes = {"2b-test": 224, "3b-224px": 224, "3b-448px": 448, "3b-896px": 896}
image_size = 224 # image_sizes[variant]
patch_size = 14
num_image_tokens = (image_size**2) // (patch_size**2)
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 2048,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 16384,
"is_encoder_decoder": False,
}
vision_config = {
"torch_dtype": precision,
"image_size": image_size,
"patch_size": patch_size,
"num_image_tokens": num_image_tokens,
"hidden_size": 1152,
"intermediate_size": 4304,
"num_hidden_layers": 27,
"num_attention_heads": 16,
"projector_hidden_act": "gelu_fast",
"vision_use_head": False,
}
final_config = PaliGemmaConfig(text_config=text_config, vision_config=vision_config, **config)
return final_config
def get_gemma_config(precision: str):
config = {
"image_token_index": None,
"pad_token_id": 0,
"bos_token_id": 2,
"eos_token_id": 1,
}
config["image_token_index"] = 257152
text_config = {
"vocab_size": 257152,
"num_hidden_layers": 18,
"num_key_value_heads": 1,
"head_dim": 256,
"torch_dtype": precision,
"hidden_size": 1024,
"hidden_activation": "gelu_pytorch_tanh",
"num_attention_heads": 8,
"intermediate_size": 4096,
"is_encoder_decoder": False,
}
final_config = GemmaConfig()
final_config.update(text_config)
return final_config
@@ -1,437 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert pi0 parameters from Jax to Pytorch
Follow [README of openpi](https://github.com/Physical-Intelligence/openpi) to create a new environment
and install the required libraries.
```bash
cd ~/code/openpi
source .venv/bin/activate
```
Example downloading parameters:
```bash
python
>>> import openpi.shared.download as download
>>> path='s3://openpi-assets/checkpoints/pi0_base/params'
>>> download.maybe_download(path)
```
Converting pi0_base:
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_base_pytorch
```
```python
python -m lerobot.policies.pi0.conversion_scripts.convert_pi0_to_hf_lerobot \
--checkpoint_dir /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params \
--output_path /home/remi_cadene/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim_pytorch
```
"""
import argparse
import pathlib
import jax
import numpy as np
import orbax.checkpoint as ocp
import torch
from jax.sharding import SingleDeviceSharding
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.conversion_scripts.conversion_utils import (
get_gemma_config,
get_paligemma_config,
)
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
PRECISIONS = {"bfloat16": torch.bfloat16, "float32": torch.float32, "float16": torch.float16}
def slice_paligemma_state_dict(state_dict, config):
suffix = "/value" if "img/embedding/kernel/value" in state_dict else ""
# fmt: off
# patch embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.weight"] = state_dict.pop(f"img/embedding/kernel{suffix}").transpose(
3, 2, 0, 1
)
state_dict["paligemma.vision_tower.vision_model.embeddings.patch_embedding.bias"] = state_dict.pop(f"img/embedding/bias{suffix}")
# positional embeddings
state_dict["paligemma.vision_tower.vision_model.embeddings.position_embedding.weight"] = state_dict.pop(f"img/pos_embedding{suffix}").reshape(
-1, config.vision_config.hidden_size
)
# extract vision layers to be sliced at index 0. There are 27 layers in the base model.
encoderblock_layernorm0_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/scale{suffix}")
encoderblock_layernorm0_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_0/bias{suffix}")
encoderblock_layernorm1_scale = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/scale{suffix}")
encoderblock_layernorm1_bias = state_dict.pop(f"img/Transformer/encoderblock/LayerNorm_1/bias{suffix}")
encoderblock_mlp_dense0_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/kernel{suffix}")
encoderblock_mlp_dense0_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_0/bias{suffix}")
encoderblock_mlp_dense1_kernel= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/kernel{suffix}")
encoderblock_mlp_dense1_bias= state_dict.pop(f"img/Transformer/encoderblock/MlpBlock_0/Dense_1/bias{suffix}")
encoderblock_attention_0_key_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/kernel{suffix}")
encoderblock_attention_0_key_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/key/bias{suffix}")
encoderblock_attention_0_value_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/kernel{suffix}")
encoderblock_attention_0_value_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/value/bias{suffix}")
encoderblock_attention_0_query_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/kernel{suffix}")
encoderblock_attention_0_query_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/query/bias{suffix}")
encoderblock_attention_0_out_kernel = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/kernel{suffix}")
encoderblock_attention_0_out_bias = state_dict.pop(f"img/Transformer/encoderblock/MultiHeadDotProductAttention_0/out/bias{suffix}")
for i in range(config.vision_config.num_hidden_layers):
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.weight"] = encoderblock_layernorm0_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm1.bias"] = encoderblock_layernorm0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.weight"] = encoderblock_layernorm1_scale[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.layer_norm2.bias"] = encoderblock_layernorm1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.weight"] = encoderblock_mlp_dense0_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc1.bias"] = encoderblock_mlp_dense0_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.weight"] = encoderblock_mlp_dense1_kernel[i].transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.mlp.fc2.bias"] = encoderblock_mlp_dense1_bias[i]
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.weight"] = encoderblock_attention_0_key_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.k_proj.bias"] = encoderblock_attention_0_key_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.weight"] = encoderblock_attention_0_value_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.v_proj.bias"] = encoderblock_attention_0_value_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.weight"] = encoderblock_attention_0_query_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.q_proj.bias"] = encoderblock_attention_0_query_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.weight"] = encoderblock_attention_0_out_kernel[i].reshape(-1, config.vision_config.hidden_size).transpose()
state_dict[f"paligemma.vision_tower.vision_model.encoder.layers.{i}.self_attn.out_proj.bias"] = encoderblock_attention_0_out_bias[i].reshape(-1, config.vision_config.hidden_size).reshape(-1)
state_dict["paligemma.vision_tower.vision_model.post_layernorm.weight"] = state_dict.pop(f"img/Transformer/encoder_norm/scale{suffix}").transpose()
state_dict["paligemma.vision_tower.vision_model.post_layernorm.bias"] = state_dict.pop(f"img/Transformer/encoder_norm/bias{suffix}")
# multimodal projector
state_dict['paligemma.multi_modal_projector.linear.weight'] = state_dict.pop(f"img/head/kernel{suffix}").transpose()
state_dict['paligemma.multi_modal_projector.linear.bias'] = state_dict.pop(f"img/head/bias{suffix}")
# text decoder (gemma)
embedding_vector = state_dict.pop(f"llm/embedder/input_embedding{suffix}")
state_dict["paligemma.language_model.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm/scale{suffix}")
for i in range(config.text_config.num_hidden_layers):
# llm_attention_q_einsum[i].shape = (8, 2048, 256)
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
# llm_attention_kv_einsum[i, 0, 0].shape = (2048, 256)
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
# llm_attention_kv_einsum[i, 1, 0].shape = (2048, 256)
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 2048)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].transpose(2, 0, 1).reshape(config.text_config.num_attention_heads * config.text_config.head_dim, config.text_config.hidden_size)
state_dict[f"paligemma.language_model.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"paligemma.language_model.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"paligemma.language_model.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["paligemma.language_model.model.norm.weight"] = state_dict.pop(f"llm/final_norm/scale{suffix}")
state_dict["paligemma.language_model.lm_head.weight"] = embedding_vector # weights are tied.
# fmt: on
expert_dict = {}
final_state_dict = {}
for key, value in state_dict.items():
if key not in [
f"llm/final_norm_1/scale{suffix}",
f"llm/layers/attn/attn_vec_einsum_1/w{suffix}",
f"llm/layers/attn/kv_einsum_1/w{suffix}",
f"llm/layers/attn/q_einsum_1/w{suffix}",
f"llm/layers/mlp_1/gating_einsum{suffix}",
f"llm/layers/mlp_1/linear{suffix}",
f"llm/layers/pre_attention_norm_1/scale{suffix}",
f"llm/layers/pre_ffw_norm_1/scale{suffix}",
]:
final_state_dict[key] = torch.from_numpy(value)
else:
expert_dict[key] = value
return final_state_dict, expert_dict
def slice_gemma_state_dict(state_dict, config, num_expert=1):
# fmt: off
# text decoder (gemma)
# no embedding vector, the expert just has the decoder layers
embedding_vector = torch.zeros([config.vocab_size, config.hidden_size])
state_dict["gemma_expert.model.embed_tokens.weight"] = embedding_vector
# pop the einsum attention + mlp representations. There are 18 layers in gemma-2b.
suffix = "/value" if f"llm/layers/attn/attn_vec_einsum_{num_expert}/w/value" in state_dict else ""
llm_attention_attn_vec_einsum = state_dict.pop(f"llm/layers/attn/attn_vec_einsum_{num_expert}/w{suffix}")
llm_attention_kv_einsum = state_dict.pop(f"llm/layers/attn/kv_einsum_{num_expert}/w{suffix}")
llm_attention_q_einsum = state_dict.pop(f"llm/layers/attn/q_einsum_{num_expert}/w{suffix}")
llm_mlp_gating_einsum = state_dict.pop(f"llm/layers/mlp_{num_expert}/gating_einsum{suffix}")
llm_mlp_linear = state_dict.pop(f"llm/layers/mlp_{num_expert}/linear{suffix}")
# TODO verify correctness of layer norm loading
llm_input_layernorm = state_dict.pop(f"llm/layers/pre_attention_norm_{num_expert}/scale{suffix}")
llm_post_attention_layernorm = state_dict.pop(f"llm/layers/pre_ffw_norm_{num_expert}/scale{suffix}")
for i in range(config.num_hidden_layers):
q_proj_weight_reshaped = llm_attention_q_einsum[i].transpose(0, 2, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.q_proj.weight"] = q_proj_weight_reshaped
k_proj_weight_reshaped = llm_attention_kv_einsum[i, 0, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.k_proj.weight"] = k_proj_weight_reshaped
v_proj_weight_reshaped = llm_attention_kv_einsum[i, 1, 0].transpose()
state_dict[f"gemma_expert.model.layers.{i}.self_attn.v_proj.weight"] = v_proj_weight_reshaped
# output projection.
# llm_attention_attn_vec_einsum[i].shape = (8, 256, 1024)
o_proj_weight_reshaped = llm_attention_attn_vec_einsum[i].reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1,0)# .transpose(2, 0, 1).reshape(config.num_attention_heads * config.head_dim, config.hidden_size).transpose(1, 0)
state_dict[f"gemma_expert.model.layers.{i}.self_attn.o_proj.weight"] = o_proj_weight_reshaped
# mlp layers
gate_proj_weight = llm_mlp_gating_einsum[i, 0]
state_dict[f"gemma_expert.model.layers.{i}.mlp.gate_proj.weight"] = gate_proj_weight.transpose()
up_proj_weight = llm_mlp_gating_einsum[i, 1]
state_dict[f"gemma_expert.model.layers.{i}.mlp.up_proj.weight"] = up_proj_weight.transpose()
state_dict[f"gemma_expert.model.layers.{i}.mlp.down_proj.weight"] = llm_mlp_linear[i].transpose()
state_dict[f"gemma_expert.model.layers.{i}.input_layernorm.weight"] = llm_input_layernorm[i]
state_dict[f"gemma_expert.model.layers.{i}.post_attention_layernorm.weight"] = llm_post_attention_layernorm[i]
state_dict["gemma_expert.model.norm.weight"] = state_dict.pop(f"llm/final_norm_{num_expert}/scale{suffix}")
state_dict["gemma_expert.lm_head.weight"] = embedding_vector # weights are tied. (and zeros here)
# fmt: on
final_state_dict = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
final_state_dict[key] = torch.from_numpy(value)
else:
final_state_dict[key] = value
return final_state_dict
def flatten_for_memory(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_memory(v, new_key))
else:
out[new_key] = np.array(v) # Ensure conversion to np.array for consistency
return out
def flatten_for_npz(tree, parent_key=""):
out = {}
for k, v in tree.items():
new_key = f"{parent_key}/{k}" if parent_key else k
if isinstance(v, dict):
out.update(flatten_for_npz(v, new_key))
else:
# bf16/f32 here?
out[new_key] = np.array(v)
return out
def slice_initial_orbax_checkpoint(checkpoint_dir: str):
params_path = pathlib.Path(checkpoint_dir).resolve()
checkpointer = ocp.PyTreeCheckpointer()
metadata = checkpointer.metadata(params_path)
print("Metadata keys:", list(metadata.keys()))
params_name = "params"
item = {params_name: metadata[params_name]}
device = jax.local_devices()[0] # Use the first local device
sharding = SingleDeviceSharding(device)
restored = checkpointer.restore(
params_path,
ocp.args.PyTreeRestore(
item=item,
restore_args=jax.tree_util.tree_map(
lambda _: ocp.ArrayRestoreArgs(
restore_type=jax.Array, # or np.ndarray, but bf16 is annoying about it
sharding=sharding,
),
item,
),
transforms={},
),
)
params = restored[params_name]
# get params for PaliGemma
pali_params = params["PaliGemma"]
del params["PaliGemma"]
pali_params_flat = flatten_for_npz(pali_params)
return {"paligemma_params": pali_params_flat, "projection_params": params}
def update_keys_with_prefix(d: dict, prefix: str) -> dict:
"""Update dictionary keys by adding a prefix."""
return {f"{prefix}{key}": value for key, value in d.items()}
def convert_pi0_checkpoint(checkpoint_dir: str, precision: str, tokenizer_id: str, output_path: str):
# Break down orbax ckpts - they are in OCDBT
initial_params = slice_initial_orbax_checkpoint(checkpoint_dir=checkpoint_dir)
# process projection params
keys = [
"state_proj",
"action_in_proj",
"action_out_proj",
"action_time_mlp_in",
"action_time_mlp_out",
]
projection_params = {}
for key in keys:
kernel_params = initial_params["projection_params"][key]["kernel"]
bias_params = initial_params["projection_params"][key]["bias"]
if isinstance(kernel_params, dict):
weight = kernel_params["value"]
bias = bias_params["value"]
else:
weight = kernel_params
bias = bias_params
projection_params[f"{key}.weight"] = torch.from_numpy(np.array(weight)).T
projection_params[f"{key}.bias"] = torch.from_numpy(np.array(bias))
# Process PaliGemma weights
paligemma_config = get_paligemma_config(precision)
paligemma_params, gemma_raw_dictionary = slice_paligemma_state_dict(
initial_params["paligemma_params"], paligemma_config
)
# Process Gemma weights (at this stage they are unused)
gemma_config = get_gemma_config(precision)
gemma_params = slice_gemma_state_dict(gemma_raw_dictionary, config=gemma_config)
# Instantiate model from configs
if "pi0_aloha_sim" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=2,
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=False,
)
elif "pi0_aloha_towel" in checkpoint_dir:
pi0_config = PI0Config(
adapt_to_pi_aloha=True,
use_delta_joint_actions_aloha=True,
)
elif "pi0_base" in checkpoint_dir:
pi0_config = PI0Config(
empty_cameras=0,
adapt_to_pi_aloha=False,
use_delta_joint_actions_aloha=False,
)
else:
raise ValueError()
# gemma_config=gemma_config, paligemma_config=paligemma_config)
pi0_model = PI0Policy(pi0_config)
paligemma_params = update_keys_with_prefix(paligemma_params, "model.paligemma_with_expert.")
gemma_params = update_keys_with_prefix(gemma_params, "model.paligemma_with_expert.")
projection_params = update_keys_with_prefix(projection_params, "model.")
# load state dict
torch_dtype = PRECISIONS[precision]
pi0_model.load_state_dict({**paligemma_params, **gemma_params, **projection_params})
pi0_model = pi0_model.to(torch_dtype)
# pi0_tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
pi0_model.save_pretrained(output_path, safe_serialization=True)
# pi0_tokenizer.save_pretrained(output_path, dtype=torch_dtype)
# assert that model loads properly
del pi0_model
PI0Policy.from_pretrained(output_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--checkpoint_dir",
default="/raid/pablo/.cache/openpi/openpi-assets/checkpoints/pi0_aloha_sim/params",
type=str,
help="Path to the ocdbt checkpoint",
)
parser.add_argument(
"--precision",
choices=["float32", "bfloat16", "float16"],
default="float32",
type=str,
help="Precision identifier for model conversion - should match the base checkpoint precision.",
)
# tokenizer is identical to paligemma, it appears
parser.add_argument(
"--tokenizer_hub_id",
default="google/paligemma-3b-pt-224",
type=str,
help="Hub path to the tokenizer to save",
)
parser.add_argument(
"--output_path",
required=True,
type=str,
help="Path to save converted weights to",
)
args = parser.parse_args()
convert_pi0_checkpoint(
checkpoint_dir=args.checkpoint_dir,
precision=args.precision,
tokenizer_id=args.tokenizer_hub_id,
output_path=args.output_path,
)
-141
View File
@@ -1,141 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F # noqa: N812
from packaging.version import Version
if Version(torch.__version__) > Version("2.5.0"):
# Ffex attention is only available from torch 2.5 onwards
from torch.nn.attention.flex_attention import (
_mask_mod_signature,
_round_up_to_multiple,
create_block_mask,
create_mask,
flex_attention,
)
# @torch.compile(dynamic=False)
def flex_attention_forward(
attention_mask: torch.Tensor,
batch_size: int,
head_dim: int,
query_states: torch.Tensor,
key_states: torch.Tensor,
value_states: torch.Tensor,
scaling=None,
):
"""
This is defined out of classes to make compile happy.
"""
original_dtype = query_states.dtype
num_att_heads = 8
num_key_value_heads = 1
num_key_value_groups = num_att_heads // num_key_value_heads
key_states = key_states[:, :, :, None, :]
key_states = key_states.expand(
batch_size, key_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, key_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :]
value_states = value_states.expand(
batch_size, value_states.shape[1], num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, value_states.shape[1], num_key_value_heads * num_key_value_groups, head_dim
)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
query_states = query_states.to(torch.float32)
key_states = key_states.to(torch.float32)
value_states = value_states.to(torch.float32)
causal_mask = attention_mask
if causal_mask is not None:
causal_mask = causal_mask[:, None, :, : key_states.shape[2]]
if causal_mask.shape[1] == 1 and query_states.shape[1] > 1:
causal_mask = causal_mask.expand(-1, query_states.shape[1], -1, -1)
def precomputed_mask_factory(precomputed_mask: torch.Tensor) -> _mask_mod_signature:
def mask_mod(b, h, q_idx, kv_idx):
# Danger zone: if b,h,q_idx,kv_idx exceed the shape, device-side assert occurs.
return precomputed_mask[b][h][q_idx][kv_idx]
return mask_mod
b_mask, h_mask, q_len, kv_len = causal_mask.shape # The shape of your mask
block_size = 128
q_len_rounded = _round_up_to_multiple(q_len, block_size)
kv_len_rounded = _round_up_to_multiple(kv_len, block_size)
# *CRITICAL* we do need to expand here, else we get a CUDA index error
pad_q = q_len_rounded - q_len
pad_k = kv_len_rounded - kv_len
padded_causal_mask = F.pad(causal_mask, (0, pad_k, 0, pad_q), value=0.0)
mask_mod_fn_orig = precomputed_mask_factory(padded_causal_mask)
mask_4d = create_mask(
mod_fn=mask_mod_fn_orig,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
device=causal_mask.device,
_compile=False,
)
mask_mod_fn_padded = precomputed_mask_factory(mask_4d)
block_mask = create_block_mask(
mask_mod=mask_mod_fn_padded,
B=b_mask,
H=h_mask,
Q_LEN=q_len_rounded,
KV_LEN=kv_len_rounded,
BLOCK_SIZE=block_size,
device=causal_mask.device,
_compile=False,
)
# mask is applied inside the kernel, ideally more efficiently than score_mod.
attn_output, attention_weights = flex_attention(
query_states,
key_states,
value_states,
block_mask=block_mask,
enable_gqa=True, # because we shaped query/key states for GQA
scale=head_dim**-0.5 if scaling is None else scaling,
return_lse=True,
)
attn_output = attn_output.to(dtype=original_dtype)
attn_output = attn_output.transpose(1, 2).contiguous() # [B, Q_LEN, H, head_dim]
attn_output = attn_output.reshape(
batch_size,
-1,
attn_output.shape[2] * attn_output.shape[3], # merges [H, head_dim]
)
return attn_output
-705
View File
@@ -1,705 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
Disclaimer: It is not expected to perform as well as the original implementation.
Install pi0 extra dependencies:
```bash
pip install -e ".[pi0]"
```
Example of finetuning the pi0 pretrained model (`pi0_base` in `openpi`):
```bash
lerobot-train \
--policy.path=lerobot/pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of finetuning the pi0 neural network with PaliGemma and expert Gemma
pretrained with VLM default parameters before pi0 finetuning:
```bash
lerobot-train \
--policy.type=pi0 \
--dataset.repo_id=danaaubakirova/koch_test
```
Example of using the pi0 pretrained model outside LeRobot training framework:
```python
policy = Pi0Policy.from_pretrained("lerobot/pi0")
```
"""
import math
from collections import deque
import torch
import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn
from lerobot.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.paligemma_with_expert import (
PaliGemmaWithExpertConfig,
PaliGemmaWithExpertModel,
)
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.utils import get_safe_dtype
def create_sinusoidal_pos_embedding(
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
) -> Tensor:
"""Computes sine-cosine positional embedding vectors for scalar positions."""
if dimension % 2 != 0:
raise ValueError(f"dimension ({dimension}) must be divisible by 2")
if time.ndim != 1:
raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")
dtype = get_safe_dtype(torch.float64, device.type)
fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
period = min_period * (max_period / min_period) ** fraction
# Compute the outer product
scaling_factor = 1.0 / period * 2 * math.pi
sin_input = scaling_factor[None, :] * time[:, None]
pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
return pos_emb
def make_att_2d_masks(pad_masks, att_masks):
"""Copied from big_vision.
Tokens can attend to valid inputs tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:
[[1 1 1 1 1 1]]: pure causal attention.
[[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
themselves and the last 3 tokens have a causal attention. The first
entry could also be a 1 without changing behaviour.
[[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
block can attend all previous blocks and all tokens on the same block.
Args:
input_mask: bool[B, N] true if its part of the input, false if padding.
mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
it and 0 where it shares the same attention mask as the previous token.
"""
if att_masks.ndim != 2:
raise ValueError(att_masks.ndim)
if pad_masks.ndim != 2:
raise ValueError(pad_masks.ndim)
cumsum = torch.cumsum(att_masks, dim=1)
att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
att_2d_masks = att_2d_masks & pad_2d_masks
return att_2d_masks
def resize_with_pad(img, width, height, pad_value=-1):
# assume no-op when width height fits already
if img.ndim != 4:
raise ValueError(f"(b,c,h,w) expected, but {img.shape}")
cur_height, cur_width = img.shape[2:]
ratio = max(cur_width / width, cur_height / height)
resized_height = int(cur_height / ratio)
resized_width = int(cur_width / ratio)
resized_img = F.interpolate(
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
)
pad_height = max(0, int(height - resized_height))
pad_width = max(0, int(width - resized_width))
# pad on left and top of image
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
return padded_img
def pad_vector(vector, new_dim):
"""Can be (batch_size x sequence_length x features_dimension)
or (batch_size x features_dimension)
"""
if vector.shape[-1] == new_dim:
return vector
shape = list(vector.shape)
current_dim = shape[-1]
shape[-1] = new_dim
new_vector = torch.zeros(*shape, dtype=vector.dtype, device=vector.device)
new_vector[..., :current_dim] = vector
return new_vector
def normalize(x, min_val, max_val):
return (x - min_val) / (max_val - min_val)
def unnormalize(x, min_val, max_val):
return x * (max_val - min_val) + min_val
def safe_arcsin(value):
# This ensures that the input stays within
# [1,1] to avoid invalid values for arcsin
return torch.arcsin(torch.clamp(value, -1.0, 1.0))
def aloha_gripper_to_angular(value):
# Aloha transforms the gripper positions into a linear space. The following code
# reverses this transformation to be consistent with pi0 which is pretrained in
# angular space.
#
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_POSITION_OPEN, PUPPET_GRIPPER_POSITION_CLOSED
value = unnormalize(value, min_val=0.01844, max_val=0.05800)
# This is the inverse of the angular to linear transformation inside the Interbotix code.
def linear_to_radian(linear_position, arm_length, horn_radius):
value = (horn_radius**2 + linear_position**2 - arm_length**2) / (2 * horn_radius * linear_position)
return safe_arcsin(value)
# The constants are taken from the Interbotix code.
value = linear_to_radian(value, arm_length=0.036, horn_radius=0.022)
# Normalize to [0, 1].
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
return normalize(value, min_val=0.4, max_val=1.5)
def aloha_gripper_from_angular(value):
# Convert from the gripper position used by pi0 to the gripper position that is used by Aloha.
# Note that the units are still angular but the range is different.
# The values 0.4 and 1.5 were measured on an actual Trossen robot.
value = unnormalize(value, min_val=0.4, max_val=1.5)
# These values are coming from the Aloha code:
# PUPPET_GRIPPER_JOINT_OPEN, PUPPET_GRIPPER_JOINT_CLOSE
return normalize(value, min_val=-0.6213, max_val=1.4910)
def aloha_gripper_from_angular_inv(value):
# Directly inverts the gripper_from_angular function.
value = unnormalize(value, min_val=-0.6213, max_val=1.4910)
return normalize(value, min_val=0.4, max_val=1.5)
class PI0Policy(PreTrainedPolicy):
"""Wrapper class around PI0FlowMatching model to train and run inference within LeRobot."""
config_class = PI0Config
name = "pi0"
def __init__(
self,
config: PI0Config,
):
"""
Args:
config: Policy configuration class instance or None, in which case the default instantiation of
the configuration class is used.
"""
super().__init__(config)
config.validate_features()
self.config = config
self.model = PI0FlowMatching(config)
self.reset()
def reset(self):
"""This should be called whenever the environment is reset."""
self._action_queue = deque([], maxlen=self.config.n_action_steps)
def get_optim_params(self) -> dict:
return self.parameters()
@classmethod
def from_pretrained(cls, *args, **kwargs):
"""Override the from_pretrained method to display important disclaimer."""
print(
"⚠️ DISCLAIMER: The PI0 model is ported from JAX by the Hugging Face team. \n"
" It is not expected to perform as well as the original implementation. \n"
" Original implementation: https://github.com/Physical-Intelligence/openpi"
)
return super().from_pretrained(*args, **kwargs)
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
"""Predict a chunk of actions given environment observations."""
raise NotImplementedError("Currently not implemented for PI0")
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Select a single action given environment observations.
This method wraps `select_actions` in order to return one action at a time for execution in the
environment. It works by managing the actions in a queue and only calling `select_actions` when the
queue is empty.
"""
self.eval()
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
# querying the policy.
if len(self._action_queue) == 0:
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.model.sample_actions(
images, img_masks, lang_tokens, lang_masks, state, noise=noise
)
# Unpad actions
original_action_dim = self.config.action_feature.shape[0]
actions = actions[:, :, :original_action_dim]
if self.config.adapt_to_pi_aloha:
actions = self._pi_aloha_encode_actions(actions)
# `self.model.forward` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
# effectively has shape (n_action_steps, batch_size, *), hence the transpose.
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> tuple[Tensor, dict[str, Tensor]]:
"""Do a full training forward pass to compute the loss"""
if self.config.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
actions = self.prepare_action(batch)
actions_is_pad = batch.get("action_is_pad")
loss_dict = {}
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
loss_dict["losses_after_forward"] = losses.clone()
if actions_is_pad is not None:
in_episode_bound = ~actions_is_pad
losses = losses * in_episode_bound.unsqueeze(-1)
loss_dict["losses_after_in_ep_bound"] = losses.clone()
# Remove padding
losses = losses[:, :, : self.config.max_action_dim]
loss_dict["losses_after_rm_padding"] = losses.clone()
# For backward pass
loss = losses.mean()
# For logging
loss_dict["l2_loss"] = loss.item()
return loss, loss_dict
def prepare_images(self, batch):
"""Apply Pi0 preprocessing to the images, like resizing to 224x224 and padding to keep aspect ratio, and
convert pixel range from [0.0, 1.0] to [-1.0, 1.0] as requested by SigLIP.
"""
images = []
img_masks = []
present_img_keys = [key for key in self.config.image_features if key in batch]
missing_img_keys = [key for key in self.config.image_features if key not in batch]
if len(present_img_keys) == 0:
raise ValueError(
f"All image features are missing from the batch. At least one expected. (batch: {batch.keys()}) (image_features:{self.config.image_features})"
)
# Preprocess image features present in the batch
for key in present_img_keys:
img = batch[key]
if self.config.resize_imgs_with_padding is not None:
img = resize_with_pad(img, *self.config.resize_imgs_with_padding, pad_value=0)
# Normalize from range [0,1] to [-1,1] as expected by siglip
img = img * 2.0 - 1.0
bsize = img.shape[0]
device = img.device
mask = torch.ones(bsize, dtype=torch.bool, device=device)
images.append(img)
img_masks.append(mask)
# Create image features not present in the batch
# as fully 0 padded images.
for num_empty_cameras in range(len(missing_img_keys)):
if num_empty_cameras >= self.config.empty_cameras:
break
img = torch.ones_like(img) * -1
mask = torch.zeros_like(mask)
images.append(img)
img_masks.append(mask)
return images, img_masks
def _pi_aloha_decode_state(self, state):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
state[:, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
state[:, motor_idx] = aloha_gripper_to_angular(state[:, motor_idx])
return state
def _pi_aloha_encode_actions(self, actions):
# Flip the joints.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular(actions[:, :, motor_idx])
return actions
def _pi_aloha_encode_actions_inv(self, actions):
# Flip the joints again.
for motor_idx in [1, 2, 8, 9]:
actions[:, :, motor_idx] *= -1
# Reverse the gripper transformation that is being applied by the Aloha runtime.
for motor_idx in [6, 13]:
actions[:, :, motor_idx] = aloha_gripper_from_angular_inv(actions[:, :, motor_idx])
return actions
def prepare_state(self, batch):
"""Pad state"""
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
class PI0FlowMatching(nn.Module):
"""
π0: A Vision-Language-Action Flow Model for General Robot Control
[Paper](https://www.physicalintelligence.company/download/pi0.pdf)
[Jax code](https://github.com/Physical-Intelligence/openpi)
Designed by Physical Intelligence. Ported from Jax by Hugging Face.
┌──────────────────────────────┐
│ actions │
│ ▲ │
│ ┌┴─────┐ │
│ kv cache │Gemma │ │
│ ┌──────────►│Expert│ │
│ │ │ │ │
│ ┌┴────────┐ │x 10 │ │
│ │ │ └▲──▲──┘ │
│ │PaliGemma│ │ │ │
│ │ │ │ robot state │
│ │ │ noise │
│ └▲──▲─────┘ │
│ │ │ │
│ │ image(s) │
│ language tokens │
└──────────────────────────────┘
"""
def __init__(self, config: PI0Config):
super().__init__()
self.config = config
paligemma_with_export_config = PaliGemmaWithExpertConfig(
freeze_vision_encoder=self.config.freeze_vision_encoder,
train_expert_only=self.config.train_expert_only,
attention_implementation=self.config.attention_implementation,
)
self.paligemma_with_expert = PaliGemmaWithExpertModel(paligemma_with_export_config)
# Projections are float32
self.state_proj = nn.Linear(self.config.max_state_dim, self.config.proj_width)
self.action_in_proj = nn.Linear(self.config.max_action_dim, self.config.proj_width)
self.action_out_proj = nn.Linear(self.config.proj_width, self.config.max_action_dim)
self.action_time_mlp_in = nn.Linear(self.config.proj_width * 2, self.config.proj_width)
self.action_time_mlp_out = nn.Linear(self.config.proj_width, self.config.proj_width)
self.set_requires_grad()
def set_requires_grad(self):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
dtype=torch.float32,
device=device,
)
return noise
def sample_time(self, bsize, device):
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
time = time_beta * 0.999 + 0.001
return time
def embed_prefix(
self, images, img_masks, lang_tokens, lang_masks
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Embed images with SigLIP and language tokens with embedding layer to prepare
for PaliGemma transformer processing.
"""
# TODO: avoid list in python and torch.cat ; prefer pre-allocation with torch.empty
embs = []
pad_masks = []
att_masks = []
# TODO: remove for loop
for (
img,
img_mask,
) in zip(images, img_masks, strict=False):
img_emb = self.paligemma_with_expert.embed_image(img)
img_emb = img_emb.to(dtype=torch.bfloat16)
# Normalize image embeddings
img_emb_dim = img_emb.shape[-1]
img_emb = img_emb * torch.tensor(img_emb_dim**0.5, dtype=img_emb.dtype, device=img_emb.device)
bsize, num_img_embs = img_emb.shape[:2]
img_mask = img_mask[:, None].expand(bsize, num_img_embs)
embs.append(img_emb)
pad_masks.append(img_mask)
# Create attention masks so that image tokens attend to each other
att_masks += [0] * num_img_embs
lang_emb = self.paligemma_with_expert.embed_language_tokens(lang_tokens)
# Normalize language embeddings
lang_emb_dim = lang_emb.shape[-1]
lang_emb = lang_emb * math.sqrt(lang_emb_dim)
embs.append(lang_emb)
pad_masks.append(lang_masks)
# full attention between image and language inputs
num_lang_embs = lang_emb.shape[1]
att_masks += [0] * num_lang_embs
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=torch.bool, device=pad_masks.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def embed_suffix(self, state, noisy_actions, timestep):
"""Embed state, noisy_actions, timestep to prepare for Expert Gemma processing."""
embs = []
pad_masks = []
att_masks = []
# Embed state
state_emb = self.state_proj(state)
state_emb = state_emb.to(dtype=torch.bfloat16)
embs.append(state_emb[:, None, :])
bsize = state_emb.shape[0]
dtype = state_emb.dtype
device = state_emb.device
state_mask = torch.ones(bsize, 1, dtype=torch.bool, device=device)
pad_masks.append(state_mask)
# Set attention masks so that image and language inputs do not attend to state or actions
att_masks += [1]
# Embed timestep using sine-cosine positional encoding with sensitivity in the range [0, 1]
time_emb = create_sinusoidal_pos_embedding(
timestep, self.config.proj_width, min_period=4e-3, max_period=4.0, device=device
)
time_emb = time_emb.type(dtype=dtype)
# Fuse timestep + action information using an MLP
action_emb = self.action_in_proj(noisy_actions)
time_emb = time_emb[:, None, :].expand_as(action_emb)
action_time_emb = torch.cat([action_emb, time_emb], dim=2)
action_time_emb = self.action_time_mlp_in(action_time_emb)
action_time_emb = F.silu(action_time_emb) # swish == silu
action_time_emb = self.action_time_mlp_out(action_time_emb)
# Add to input tokens
embs.append(action_time_emb)
bsize, action_time_dim = action_time_emb.shape[:2]
action_time_mask = torch.ones(bsize, action_time_dim, dtype=torch.bool, device=device)
pad_masks.append(action_time_mask)
# Set attention masks so that image, language and state inputs do not attend to action tokens
att_masks += [1] + ([0] * (self.config.n_action_steps - 1))
embs = torch.cat(embs, dim=1)
pad_masks = torch.cat(pad_masks, dim=1)
att_masks = torch.tensor(att_masks, dtype=embs.dtype, device=embs.device)
att_masks = att_masks[None, :].expand(bsize, len(att_masks))
return embs, pad_masks, att_masks
def forward(
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
noise = self.sample_noise(actions.shape, actions.device)
if time is None:
time = self.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
(_, suffix_out), _ = self.paligemma_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
suffix_out = suffix_out[:, -self.config.n_action_steps :]
# Original openpi code, upcast attention output
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
losses = F.mse_loss(u_t, v_t, reduction="none")
return losses
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
if noise is None:
actions_shape = (bsize, self.config.n_action_steps, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks
)
prefix_att_2d_masks = make_att_2d_masks(prefix_pad_masks, prefix_att_masks)
prefix_position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1
# Compute image and language key value cache
_, past_key_values = self.paligemma_with_expert.forward(
attention_mask=prefix_att_2d_masks,
position_ids=prefix_position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, None],
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state,
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step
x_t += dt * v_t
time += dt
return x_t
def denoise_step(
self,
state,
prefix_pad_masks,
past_key_values,
x_t,
timestep,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(state, x_t, timestep)
suffix_len = suffix_pad_masks.shape[1]
batch_size = prefix_pad_masks.shape[0]
prefix_len = prefix_pad_masks.shape[1]
prefix_pad_2d_masks = prefix_pad_masks[:, None, :].expand(batch_size, suffix_len, prefix_len)
suffix_att_2d_masks = make_att_2d_masks(suffix_pad_masks, suffix_att_masks)
full_att_2d_masks = torch.cat([prefix_pad_2d_masks, suffix_att_2d_masks], dim=2)
prefix_offsets = torch.sum(prefix_pad_masks, dim=-1)[:, None]
position_ids = prefix_offsets + torch.cumsum(suffix_pad_masks, dim=1) - 1
outputs_embeds, _ = self.paligemma_with_expert.forward(
attention_mask=full_att_2d_masks,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=[None, suffix_embs],
use_cache=self.config.use_cache,
fill_kv_cache=False,
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.n_action_steps :]
suffix_out = suffix_out.to(dtype=torch.float32)
v_t = self.action_out_proj(suffix_out)
return v_t
@@ -33,7 +33,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi0_openpi.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pi0.configuration_pi0openpi import PI0OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -525,7 +525,7 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
msg = """transformers_replace is not installed correctly.
Please install it with `pip install transformers==4.53.2`
and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
and `cp -r ./src/lerobot/policies/pi0/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`"""
try:
@@ -846,7 +846,7 @@ class PI0OpenPIPolicy(PreTrainedPolicy):
"""PI0 OpenPI Policy for LeRobot."""
config_class = PI0OpenPIConfig
name = "pi0_openpi"
name = "pi0"
def __init__( # see lerobot pi0 `__init__`
self,
@@ -1,420 +0,0 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.version
from pytest import Cache
from torch import nn
from transformers import (
AutoConfig,
GemmaForCausalLM,
PaliGemmaForConditionalGeneration,
PretrainedConfig,
PreTrainedModel,
)
from transformers.models.auto import CONFIG_MAPPING
from lerobot.policies.pi0.flex_attention import flex_attention_forward
def apply_rope(x, positions, max_wavelength=10_000):
"""
Applies RoPE positions [B, L] to x [B, L, H, D].
"""
d_half = x.shape[-1] // 2
device = x.device
dtype = x.dtype
x = x.to(torch.float32)
freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
timescale = max_wavelength**freq_exponents
radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)
radians = radians[..., None, :]
sin = torch.sin(radians) # .to(dtype=dtype)
cos = torch.cos(radians) # .to(dtype=dtype)
x1, x2 = x.split(d_half, dim=-1)
res = torch.empty_like(x)
res[..., :d_half] = x1 * cos - x2 * sin
res[..., d_half:] = x2 * cos + x1 * sin
return res.to(dtype)
class PaliGemmaWithExpertConfig(PretrainedConfig):
model_type = "PaliGemmaWithExpertModel"
sub_configs = {"paligemma_config": AutoConfig, "gemma_expert_config": AutoConfig}
def __init__(
self,
paligemma_config: dict | None = None,
gemma_expert_config: dict | None = None,
freeze_vision_encoder: bool = True,
train_expert_only: bool = True,
attention_implementation: str = "eager",
**kwargs,
):
self.freeze_vision_encoder = freeze_vision_encoder
self.train_expert_only = train_expert_only
self.attention_implementation = attention_implementation
if paligemma_config is None:
# Default config from Pi0
self.paligemma_config = CONFIG_MAPPING["paligemma"](
transformers_version="4.48.1",
_vocab_size=257152,
bos_token_id=2,
eos_token_id=1,
hidden_size=2048,
image_token_index=257152,
model_type="paligemma",
pad_token_id=0,
projection_dim=2048,
text_config={
"hidden_activation": "gelu_pytorch_tanh",
"hidden_size": 2048,
"intermediate_size": 16384,
"model_type": "gemma",
"num_attention_heads": 8,
"num_hidden_layers": 18,
"num_image_tokens": 256,
"num_key_value_heads": 1,
"torch_dtype": "float32",
"vocab_size": 257152,
},
vision_config={
"hidden_size": 1152,
"intermediate_size": 4304,
"model_type": "siglip_vision_model",
"num_attention_heads": 16,
"num_hidden_layers": 27,
"num_image_tokens": 256,
"patch_size": 14,
"projection_dim": 2048,
"projector_hidden_act": "gelu_fast",
"torch_dtype": "float32",
"vision_use_head": False,
},
)
elif isinstance(self.paligemma_config, dict):
# Override Pi0 default config for PaliGemma
if "model_type" not in gemma_expert_config:
paligemma_config["model_type"] = "paligemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.paligemma_config = cfg_cls(**paligemma_config)
if gemma_expert_config is None:
# Default config from Pi0
self.gemma_expert_config = CONFIG_MAPPING["gemma"](
attention_bias=False,
attention_dropout=0.0,
bos_token_id=2,
eos_token_id=1,
head_dim=256,
hidden_act="gelu_pytorch_tanh",
hidden_activation="gelu_pytorch_tanh",
hidden_size=1024,
initializer_range=0.02,
intermediate_size=4096,
max_position_embeddings=8192,
model_type="gemma",
num_attention_heads=8,
num_hidden_layers=18,
num_key_value_heads=1,
pad_token_id=0,
rms_norm_eps=1e-06,
rope_theta=10000.0,
torch_dtype="float32",
transformers_version="4.48.1",
use_cache=True,
vocab_size=257152,
)
elif isinstance(self.gemma_expert_config, dict):
# Override Pi0 default config for Gemma Expert
if "model_type" not in gemma_expert_config:
gemma_expert_config["model_type"] = "gemma"
cfg_cls = CONFIG_MAPPING[paligemma_config["model_type"]]
self.gemma_expert_config = cfg_cls(**gemma_expert_config)
super().__init__(**kwargs)
def __post_init__(self):
super().__post_init__()
if self.train_expert_only and not self.freeze_vision_encoder:
raise ValueError(
"You set `freeze_vision_encoder=False` and `train_expert_only=True` which are not compatible."
)
if self.attention_implementation not in ["eager", "fa2", "flex"]:
raise ValueError(
f"Wrong value provided for `attention_implementation` ({self.attention_implementation}). Expected 'eager', 'fa2' or 'flex'."
)
class PaliGemmaWithExpertModel(PreTrainedModel):
config_class = PaliGemmaWithExpertConfig
def __init__(self, config: PaliGemmaWithExpertConfig):
super().__init__(config=config)
self.config = config
self.paligemma = PaliGemmaForConditionalGeneration(config=config.paligemma_config)
self.gemma_expert = GemmaForCausalLM(config=config.gemma_expert_config)
# Remove unused embed_tokens
self.gemma_expert.model.embed_tokens = None
self.to_bfloat16_like_physical_intelligence()
self.set_requires_grad()
def set_requires_grad(self):
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
for params in self.paligemma.vision_tower.parameters():
params.requires_grad = False
if self.config.train_expert_only:
self.paligemma.eval()
for params in self.paligemma.parameters():
params.requires_grad = False
def train(self, mode: bool = True):
super().train(mode)
if self.config.freeze_vision_encoder:
self.paligemma.vision_tower.eval()
if self.config.train_expert_only:
self.paligemma.eval()
def to_bfloat16_like_physical_intelligence(self):
self.paligemma = self.paligemma.to(dtype=torch.bfloat16)
params_to_change_dtype = [
"language_model.model.layers",
"gemma_expert.model.layers",
"vision_tower",
"multi_modal",
]
for name, param in self.named_parameters():
if any(selector in name for selector in params_to_change_dtype):
param.data = param.data.to(dtype=torch.bfloat16)
def embed_image(self, image: torch.Tensor):
# Handle different transformers versions
if hasattr(self.paligemma, "get_image_features"):
return self.paligemma.get_image_features(image)
else:
return self.paligemma.model.get_image_features(image)
def embed_language_tokens(self, tokens: torch.Tensor):
return self.paligemma.language_model.embed_tokens(tokens)
# TODO: break down this huge forward into modules or functions
def forward(
self,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: list[torch.FloatTensor] | Cache | None = None,
inputs_embeds: list[torch.FloatTensor] = None,
use_cache: bool | None = None,
fill_kv_cache: bool | None = None,
):
models = [self.paligemma.language_model, self.gemma_expert.model]
for hidden_states in inputs_embeds:
# TODO this is very inefficient
# dtype is always the same, batch size too (if > 1 len)
# device could be trickier in multi gpu edge cases but that's it
if hidden_states is None:
continue
batch_size = hidden_states.shape[0]
# RMSNorm
num_layers = self.paligemma.config.text_config.num_hidden_layers
head_dim = self.paligemma.config.text_config.head_dim
for layer_idx in range(num_layers):
query_states = []
key_states = []
value_states = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is None:
continue
layer = models[i].layers[layer_idx]
# normalizer = torch.tensor(models[i].config.hidden_size**0.5, dtype=hidden_states.dtype)
# hidden_states = hidden_states * normalizer
hidden_states = layer.input_layernorm(hidden_states)
input_shape = hidden_states.shape[:-1]
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
hidden_states = hidden_states.to(dtype=torch.bfloat16)
query_state = layer.self_attn.q_proj(hidden_states).view(hidden_shape)
key_state = layer.self_attn.k_proj(hidden_states).view(hidden_shape)
value_state = layer.self_attn.v_proj(hidden_states).view(hidden_shape)
query_states.append(query_state)
key_states.append(key_state)
value_states.append(value_state)
# B,L,H,D with L sequence length, H number of heads, D head dim
# concatenate on the number of embeddings/tokens
query_states = torch.cat(query_states, dim=1)
key_states = torch.cat(key_states, dim=1)
value_states = torch.cat(value_states, dim=1)
query_states = apply_rope(query_states, position_ids)
key_states = apply_rope(key_states, position_ids)
if use_cache and past_key_values is None:
past_key_values = {}
if use_cache:
if fill_kv_cache:
past_key_values[layer_idx] = {
"key_states": key_states,
"value_states": value_states,
}
else:
# TODO here, some optimization can be done - similar to a `StaticCache` we can declare the `max_len` before.
# so we create an empty cache, with just one cuda malloc, and if (in autoregressive case) we reach
# the max len, then we (for instance) double the cache size. This implementation already exists
# in `transformers`. (molbap)
key_states = torch.cat([past_key_values[layer_idx]["key_states"], key_states], dim=1)
value_states = torch.cat(
[past_key_values[layer_idx]["value_states"], value_states], dim=1
)
attention_interface = self.get_attention_interface()
att_output = attention_interface(
attention_mask, batch_size, head_dim, query_states, key_states, value_states
)
att_output = att_output.to(dtype=torch.bfloat16)
# first part of att_output is prefix (up to sequence length, [:, 0:prefix_seq_len])
outputs_embeds = []
start = 0
for i, hidden_states in enumerate(inputs_embeds):
layer = models[i].layers[layer_idx]
if hidden_states is not None:
end = start + hidden_states.shape[1]
if att_output.dtype != layer.self_attn.o_proj.weight.dtype:
att_output = att_output.to(layer.self_attn.o_proj.weight.dtype)
out_emb = layer.self_attn.o_proj(att_output[:, start:end])
# TODO: first dropout (by default 0.0)
# first residual
out_emb += hidden_states
after_first_residual = out_emb.clone()
out_emb = layer.post_attention_layernorm(out_emb)
out_emb = layer.mlp(out_emb)
# TODO: second dropout (by default 0.0)
# second residual
out_emb += after_first_residual
outputs_embeds.append(out_emb)
start = end
else:
outputs_embeds.append(None)
inputs_embeds = outputs_embeds
# final norm
outputs_embeds = []
for i, hidden_states in enumerate(inputs_embeds):
if hidden_states is not None:
out_emb = models[i].norm(hidden_states)
outputs_embeds.append(out_emb)
else:
outputs_embeds.append(None)
return outputs_embeds, past_key_values
def get_attention_interface(self):
if self.config.attention_implementation == "fa2":
attention_interface = self.flash_attention_forward
elif self.config.attention_implementation == "flex":
attention_interface = flex_attention_forward
else:
attention_interface = self.eager_attention_forward
return attention_interface
def flash_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
raise NotImplementedError("FA2 is not implemented (yet)")
def eager_attention_forward(
self, attention_mask, batch_size, head_dim, query_states, key_states, value_states
):
num_att_heads = self.config.paligemma_config.text_config.num_attention_heads
num_key_value_heads = self.config.paligemma_config.text_config.num_key_value_heads
num_key_value_groups = num_att_heads // num_key_value_heads
# query_states: batch_size, sequence_length, num_att_head, head_dim
# key_states: batch_size, sequence_length, num_key_value_head, head_dim
# value_states: batch_size, sequence_length, num_key_value_head, head_dim
sequence_length = key_states.shape[1]
key_states = key_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
key_states = key_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
value_states = value_states[:, :, :, None, :].expand(
batch_size, sequence_length, num_key_value_heads, num_key_value_groups, head_dim
)
value_states = value_states.reshape(
batch_size, sequence_length, num_key_value_heads * num_key_value_groups, head_dim
)
# Attention here is upcasted to float32 to match the original eager implementation.
query_states = query_states.to(dtype=torch.float32)
key_states = key_states.to(dtype=torch.float32)
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
att_weights = torch.matmul(query_states, key_states.transpose(2, 3))
att_weights *= head_dim**-0.5
big_neg = -2.3819763e38 # See gemma/modules.py
masked_att_weights = torch.where(attention_mask[:, None, :, :], att_weights, big_neg)
probs = nn.functional.softmax(masked_att_weights, dim=-1)
probs = probs.to(dtype=value_states.dtype)
# probs: batch_size, num_key_value_head, num_att_head, sequence_length, sequence_length
# value_states: batch_size, sequence_length, num_att_heads, head_dim
att_output = torch.matmul(probs, value_states.permute(0, 2, 1, 3))
att_output = att_output.permute(0, 2, 1, 3)
# we use -1 because sequence length can change
att_output = att_output.reshape(batch_size, -1, num_key_value_heads * num_key_value_groups * head_dim)
return att_output
-166
View File
@@ -1,166 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
PolicyProcessorPipeline,
ProcessorStep,
ProcessorStepRegistry,
RenameObservationsProcessorStep,
TokenizerProcessorStep,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
@ProcessorStepRegistry.register(name="pi0_new_line_processor")
class Pi0NewLineProcessor(ComplementaryDataProcessorStep):
"""
Ensures that the task description string ends with a newline character.
This processing step is required for compatibility with the PaliGemma tokenizer,
which expects a newline at the end of the text prompt. It handles both single
strings and lists of strings for the 'task' key in complementary data.
"""
def complementary_data(self, complementary_data):
"""
Adds a newline to the 'task' field if it doesn't already have one.
Args:
complementary_data: A dictionary that may contain a 'task' key with a
string or list of strings.
Returns:
A new dictionary with the modified 'task' field.
"""
if "task" not in complementary_data:
return complementary_data
task = complementary_data["task"]
if task is None:
return complementary_data
new_complementary_data = dict(complementary_data)
# Handle both string and list of strings
if isinstance(task, str):
# Single string: add newline if not present
if not task.endswith("\n"):
new_complementary_data["task"] = f"{task}\n"
elif isinstance(task, list) and all(isinstance(t, str) for t in task):
# List of strings: add newline to each if not present
new_complementary_data["task"] = [t if t.endswith("\n") else f"{t}\n" for t in task]
# If task is neither string nor list of strings, leave unchanged
return new_complementary_data
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
"""
This step does not alter the feature definitions.
Args:
features: The input feature dictionary.
Returns:
The unchanged feature dictionary.
"""
return features
def make_pi0_pre_post_processors(
config: PI0Config,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
"""
Constructs pre-processor and post-processor pipelines for the PI0 policy.
The pre-processing pipeline prepares input data for the model by:
1. Renaming features to match pretrained configurations.
2. Normalizing input and output features based on dataset statistics.
3. Adding a batch dimension.
4. Appending a newline character to the task description for tokenizer compatibility.
5. Tokenizing the text prompt using the PaliGemma tokenizer.
6. Moving all data to the specified device.
The post-processing pipeline handles the model's output by:
1. Moving data to the CPU.
2. Unnormalizing the output features to their original scale.
Args:
config: The configuration object for the PI0 policy.
dataset_stats: A dictionary of statistics for normalization.
preprocessor_kwargs: Additional arguments for the pre-processor pipeline.
postprocessor_kwargs: Additional arguments for the post-processor pipeline.
Returns:
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
max_length=config.tokenizer_max_length,
padding_side="right",
padding="max_length",
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
DeviceProcessorStep(device="cpu"),
]
return (
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
steps=input_steps,
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
),
PolicyProcessorPipeline[PolicyAction, PolicyAction](
steps=output_steps,
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
),
)
@@ -22,7 +22,7 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
@PreTrainedConfig.register_subclass("pi05_openpi")
@PreTrainedConfig.register_subclass("pi05")
@dataclass
class PI05OpenPIConfig(PreTrainedConfig):
# Model architecture
@@ -34,7 +34,7 @@ from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditi
from lerobot.configs.policies import PreTrainedConfig
from lerobot.constants import ACTION, OBS_STATE
from lerobot.policies.normalize import Normalize, Unnormalize
from lerobot.policies.pi05_openpi.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pi05.configuration_pi05openpi import PI05OpenPIConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
@@ -525,7 +525,7 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
msg = """transformers_replace is not installed correctly.
Please install it with `pip install transformers==4.53.2`
and `cp -r ./src/lerobot/policies/pi0_openpi/transformers_replace/* \
and `cp -r ./src/lerobot/policies/pi0/transformers_replace/* \
$(python -c "import transformers, os; print(os.path.dirname(transformers.__file__))")`"""
try:
@@ -820,7 +820,7 @@ class PI05OpenPIPolicy(PreTrainedPolicy):
"""PI05 OpenPI Policy for LeRobot."""
config_class = PI05OpenPIConfig
name = "pi05_openpi"
name = "pi05"
def __init__( # see lerobot pi0 `__init__`
self,
+1 -1
View File
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "pi0", "tdmpc", "vqbet", "pi0_openpi", "pi05_openpi"]
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
# TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower"]
-1
View File
@@ -26,7 +26,6 @@ from lerobot.constants import OBS_IMAGES, OBS_STATE
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
# NOTE: Configs need to be loaded for the client to be able to instantiate the policy config
from lerobot.policies import ACTConfig, DiffusionConfig, PI0Config, SmolVLAConfig, VQBeTConfig # noqa: F401
from lerobot.robots.robot import Robot
from lerobot.utils.utils import init_logging
@@ -19,11 +19,9 @@
[Diffusion Policy](https://huggingface.co/papers/2303.04137) treats visuomotor control as a generative diffusion process, producing smooth, multi-step action trajectories that excel at contact-rich manipulation.
{% elif model_name == "vqbet" %}
[VQ-BET](https://huggingface.co/papers/2403.03181) combines vector-quantised action tokens with Behaviour Transformers to discretise control and achieve data-efficient imitation across diverse skills.
{% elif model_name == "pi0" %}
[Pi0](https://huggingface.co/papers/2410.24164) is a generalist vision-language-action transformer that converts multimodal observations and text instructions into robot actions for zero-shot task transfer.
{% elif model_name == "pi0fast" %}
[Pi0-Fast](https://huggingface.co/papers/2501.09747) is a variant of Pi0 that uses a new tokenization method called FAST, which enables training of an autoregressive vision-language-action policy for high-frequency robotic tasks with improved performance and reduced training time.
{% elif model_name == "pi0_openpi" %}
{% elif model_name == "pi0" %}
**π₀ (Pi0)**
π₀ is a Vision-Language-Action model for general robot control, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
@@ -33,7 +31,7 @@
π₀ represents a breakthrough in robotics as the first general-purpose robot foundation model developed by Physical Intelligence. Unlike traditional robots that are narrow specialists programmed for repetitive motions, π₀ is designed to be a generalist policy that can understand visual inputs, interpret natural language instructions, and control a variety of different robots across diverse tasks.
For more details, see the [Physical Intelligence π₀ blog post](https://www.physicalintelligence.company/blog/pi0).
{% elif model_name == "pi05_openpi" %}
{% elif model_name == "pi05" %}
**π₀.₅ (Pi05) Policy**
π₀.₅ is a Vision-Language-Action model with open-world generalization, from Physical Intelligence. The LeRobot implementation is adapted from their open source OpenPI repository.
+5 -5
View File
@@ -13,7 +13,7 @@ pytestmark = pytest.mark.skipif(
reason="This test requires local OpenPI installation and is not meant for CI",
)
from lerobot.policies.pi05_openpi import PI05OpenPIConfig, PI05OpenPIPolicy # noqa: E402
from lerobot.policies.pi05 import PI05Config, PI05Policy # noqa: E402
from tests.utils import require_cuda # noqa: E402
@@ -22,7 +22,7 @@ def test_pi05_model_architecture():
"""Test that pi05=True creates the correct model architecture."""
# Create config
config = PI05OpenPIConfig(
config = PI05Config(
max_action_dim=7,
max_state_dim=14,
dtype="float32",
@@ -73,7 +73,7 @@ def test_pi05_model_architecture():
}
# Instantiate policy
policy = PI05OpenPIPolicy(config, dataset_stats)
policy = PI05Policy(config, dataset_stats)
# Verify pi05 model components exist
# Check that time_mlp layers exist (for AdaRMS conditioning)
@@ -104,7 +104,7 @@ def test_pi05_forward_pass():
"""Test forward pass with"""
# Create config
config = PI05OpenPIConfig(
config = PI05Config(
max_action_dim=7,
max_state_dim=14,
dtype="float32",
@@ -150,7 +150,7 @@ def test_pi05_forward_pass():
}
# Instantiate policy
policy = PI05OpenPIPolicy(config, dataset_stats)
policy = PI05Policy(config, dataset_stats)
# Create test batch
batch_size = 2
+2 -2
View File
@@ -14,7 +14,7 @@ pytestmark = pytest.mark.skipif(
)
from lerobot.policies.factory import make_policy_config # noqa: E402
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi0 import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from tests.utils import require_cuda # noqa: E402
@@ -96,7 +96,7 @@ def test_config_creation():
"""Test policy config creation through factory."""
try:
config = make_policy_config(
policy_type="pi0_openpi",
policy_type="pi0",
max_action_dim=7,
max_state_dim=14,
)
@@ -21,7 +21,7 @@ from openpi.models_pytorch import preprocessing_pytorch as openpi_preprocessing
from openpi.models_pytorch.pi0_pytorch import PI0Pytorch # noqa: E402
from transformers import AutoTokenizer # noqa: E402
from lerobot.policies.pi0_openpi import PI0OpenPIConfig, PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi0 import PI0Config, PI0Policy # noqa: E402
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
@@ -68,9 +68,7 @@ class PI0BaseOriginalConfig:
def instantiate_lerobot_pi0(from_pretrained: bool = False):
if from_pretrained:
# Load the policy first
policy = PI0OpenPIPolicy.from_pretrained(
pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True
)
policy = PI0Policy.from_pretrained(pretrained_name_or_path="pepijn223/pi0_base_fp32", strict=True)
# Then reinitialize the normalization with proper stats
from lerobot.policies.normalize import Normalize, Unnormalize
@@ -84,10 +82,8 @@ def instantiate_lerobot_pi0(from_pretrained: bool = False):
policy.config.output_features, policy.config.normalization_mapping, DUMMY_DATASET_STATS
)
else:
config = PI0OpenPIConfig(
max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32"
)
policy = PI0OpenPIPolicy(config, DUMMY_DATASET_STATS)
config = PI0Config(max_action_dim=DUMMY_ACTION_DIM, max_state_dim=DUMMY_STATE_DIM, dtype="float32")
policy = PI0Policy(config, DUMMY_DATASET_STATS)
policy.to(DEVICE)
return policy
+9 -9
View File
@@ -18,8 +18,8 @@ pytestmark = pytest.mark.skipif(
reason="This test requires HuggingFace authentication and is not meant for CI",
)
from lerobot.policies.pi0_openpi import PI0OpenPIPolicy # noqa: E402
from lerobot.policies.pi05_openpi.modeling_pi05openpi import PI05OpenPIPolicy # noqa: E402
from lerobot.policies.pi0 import PI0Policy # noqa: E402
from lerobot.policies.pi05.modeling_pi05openpi import PI05Policy # noqa: E402
def create_dummy_stats(config):
@@ -48,13 +48,13 @@ def create_dummy_stats(config):
# Test data for all 6 base models
MODEL_TEST_PARAMS = [
# PI0 models
("pepijn223/pi0_base_fp32", "PI0", PI0OpenPIPolicy),
("pepijn223/pi0_droid_fp32", "PI0", PI0OpenPIPolicy),
("pepijn223/pi0_libero_fp32", "PI0", PI0OpenPIPolicy),
("pepijn223/pi0_base_fp32", "PI0", PI0Policy),
("pepijn223/pi0_droid_fp32", "PI0", PI0Policy),
("pepijn223/pi0_libero_fp32", "PI0", PI0Policy),
# PI0.5 models
("pepijn223/pi05_base_fp32", "PI0.5", PI05OpenPIPolicy),
("pepijn223/pi05_droid_fp32", "PI0.5", PI05OpenPIPolicy),
("pepijn223/pi05_libero_fp32", "PI0.5", PI05OpenPIPolicy),
("pepijn223/pi05_base_fp32", "PI0.5", PI05Policy),
("pepijn223/pi05_droid_fp32", "PI0.5", PI05Policy),
("pepijn223/pi05_libero_fp32", "PI0.5", PI05Policy),
]
@@ -65,7 +65,7 @@ def test_all_base_models_hub_loading(model_id, model_type, policy_class):
Args:
model_id: HuggingFace model ID (e.g., "pepijn223/pi0_base_fp32")
model_type: Model type ("PI0" or "PI0.5")
policy_class: Policy class to use (PI0OpenPIPolicy or PI05OpenPIPolicy)
policy_class: Policy class to use (PI0Policy or PI05Policy)
"""
print(f"\n{'=' * 80}")
print(f"Testing {model_type} model: {model_id}")
-424
View File
@@ -1,424 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for PI0 policy processor."""
from unittest.mock import patch
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.processor_pi0 import Pi0NewLineProcessor, make_pi0_pre_post_processors
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeviceProcessorStep,
EnvTransition,
NormalizerProcessorStep,
ProcessorStep,
RenameObservationsProcessorStep,
TransitionKey,
UnnormalizerProcessorStep,
)
from lerobot.processor.converters import create_transition, transition_to_batch
class MockTokenizerProcessorStep(ProcessorStep):
"""Mock tokenizer processor step for testing."""
def __init__(self, *args, **kwargs):
# Accept any arguments to mimic the real TokenizerProcessorStep interface
pass
def __call__(self, transition: EnvTransition) -> EnvTransition:
# Pass through transition unchanged
return transition
def transform_features(self, features):
# Pass through features unchanged
return features
def create_default_config():
"""Create a default PI0 configuration for testing."""
config = PI0Config()
config.input_features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
OBS_IMAGE: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
}
config.output_features = {
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(6,)),
}
config.normalization_mapping = {
FeatureType.STATE: NormalizationMode.MEAN_STD,
FeatureType.VISUAL: NormalizationMode.IDENTITY,
FeatureType.ACTION: NormalizationMode.MIN_MAX,
}
config.device = "cpu"
config.tokenizer_max_length = 128
return config
def create_default_stats():
"""Create default dataset statistics for testing."""
return {
OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
OBS_IMAGE: {}, # No normalization for images
ACTION: {"min": torch.full((6,), -1.0), "max": torch.ones(6)},
}
def test_make_pi0_processor_basic():
"""Test basic creation of PI0 processor."""
config = create_default_config()
stats = create_default_stats()
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Check processor names
assert preprocessor.name == "policy_preprocessor"
assert postprocessor.name == "policy_postprocessor"
# Check steps in preprocessor
assert len(preprocessor.steps) == 6
assert isinstance(preprocessor.steps[0], RenameObservationsProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
# Step 3 would be TokenizerProcessorStep but it's mocked
assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
assert isinstance(postprocessor.steps[0], UnnormalizerProcessorStep)
assert isinstance(postprocessor.steps[1], DeviceProcessorStep)
def test_pi0_newline_processor_single_task():
"""Test Pi0NewLineProcessor with single task string."""
processor = Pi0NewLineProcessor()
# Test with task that doesn't have newline
transition = create_transition(complementary_data={"task": "test task"})
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
# Test with task that already has newline
transition = create_transition(complementary_data={"task": "test task\n"})
result = processor(transition)
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == "test task\n"
def test_pi0_newline_processor_list_of_tasks():
"""Test Pi0NewLineProcessor with list of task strings."""
processor = Pi0NewLineProcessor()
# Test with list of tasks
tasks = ["task1", "task2\n", "task3"]
transition = create_transition(complementary_data={"task": tasks})
result = processor(transition)
expected = ["task1\n", "task2\n", "task3\n"]
assert result[TransitionKey.COMPLEMENTARY_DATA]["task"] == expected
def test_pi0_newline_processor_empty_transition():
"""Test Pi0NewLineProcessor with empty transition."""
processor = Pi0NewLineProcessor()
# Test with no complementary_data
transition = create_transition()
result = processor(transition)
assert result == transition
# Test with complementary_data but no task
transition = create_transition(complementary_data={"other": "data"})
result = processor(transition)
assert result == transition
# Test with None task
transition = create_transition(complementary_data={"task": None})
result = processor(transition)
assert result == transition
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_cuda():
"""Test PI0 processor with CUDA device."""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
def __call__(self, transition):
return transition
def state_dict(self):
return {}
def load_state_dict(self, state):
pass
def reset(self):
pass
def get_config(self):
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Create CPU data
observation = {
OBS_STATE: torch.randn(10),
OBS_IMAGE: torch.randn(3, 224, 224),
}
action = torch.randn(6)
transition = create_transition(observation, action, complementary_data={"task": "test task"})
batch = transition_to_batch(transition)
# Process through preprocessor
processed = preprocessor(batch)
# Check that data is on CUDA
assert processed[OBS_STATE].device.type == "cuda"
assert processed[OBS_IMAGE].device.type == "cuda"
assert processed[TransitionKey.ACTION.value].device.type == "cuda"
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_accelerate_scenario():
"""Test PI0 processor in simulated Accelerate scenario."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
def __call__(self, transition):
return transition
def state_dict(self):
return {}
def load_state_dict(self, state):
pass
def reset(self):
pass
def get_config(self):
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Simulate Accelerate: data already on GPU and batched
device = torch.device("cuda:0")
observation = {
OBS_STATE: torch.randn(1, 10).to(device),
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
batch = transition_to_batch(transition)
# Process through preprocessor
processed = preprocessor(batch)
# Check that data stays on same GPU
assert processed[OBS_STATE].device == device
assert processed[OBS_IMAGE].device == device
assert processed[TransitionKey.ACTION.value].device == device
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_pi0_processor_multi_gpu():
"""Test PI0 processor with multi-GPU setup."""
config = create_default_config()
config.device = "cuda:0"
stats = create_default_stats()
# Mock the tokenizer processor to act as pass-through
class MockTokenizerProcessorStep(ProcessorStep):
def __init__(self, *args, **kwargs):
pass
def __call__(self, transition):
return transition
def state_dict(self):
return {}
def load_state_dict(self, state):
pass
def reset(self):
pass
def get_config(self):
return {"tokenizer_name": "google/paligemma-3b-pt-224"}
def transform_features(self, features):
return features
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
stats,
)
# Simulate data on different GPU
device = torch.device("cuda:1")
observation = {
OBS_STATE: torch.randn(1, 10).to(device),
OBS_IMAGE: torch.randn(1, 3, 224, 224).to(device),
}
action = torch.randn(1, 6).to(device)
transition = create_transition(observation, action, complementary_data={"task": ["test task"]})
batch = transition_to_batch(transition)
# Process through preprocessor
processed = preprocessor(batch)
# Check that data stays on cuda:1
assert processed[OBS_STATE].device == device
assert processed[OBS_IMAGE].device == device
assert processed[TransitionKey.ACTION.value].device == device
def test_pi0_processor_without_stats():
"""Test PI0 processor creation without dataset statistics."""
config = create_default_config()
# Mock the tokenizer processor
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, postprocessor = make_pi0_pre_post_processors(
config,
dataset_stats=None,
)
# Should still create processors
assert preprocessor is not None
assert postprocessor is not None
def test_pi0_newline_processor_state_dict():
"""Test Pi0NewLineProcessor state dict methods."""
processor = Pi0NewLineProcessor()
# Test state_dict (should be empty)
state = processor.state_dict()
assert state == {}
# Test load_state_dict (should do nothing)
processor.load_state_dict({})
# Test reset (should do nothing)
processor.reset()
# Test get_config
config = processor.get_config()
assert config == {}
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
stats = create_default_stats()
config.device = "cuda"
with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
preprocessor, _ = make_pi0_pre_post_processors(
config,
stats,
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
norm_step = step # Now type checker knows this is NormalizerProcessorStep
modified_steps.append(
NormalizerProcessorStep(
features=norm_step.features,
norm_map=norm_step.norm_map,
stats=norm_step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
normalizer_step = preprocessor.steps[5] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data with both state and visual observations
observation = {
OBS_STATE: torch.randn(10, dtype=torch.float32), # PI0 expects size 10
OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
}
action = torch.randn(6, dtype=torch.float32) # PI0 expects size 6
transition = create_transition(
observation, action, complementary_data={"task": "test bfloat16 adaptation"}
)
batch = transition_to_batch(transition)
# Process through full pipeline
processed = preprocessor(batch)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[OBS_STATE].dtype == torch.bfloat16
assert processed[OBS_IMAGE].dtype == torch.bfloat16 # IDENTITY normalization still gets dtype conversion
assert processed[TransitionKey.ACTION.value].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
# Check state stats (has normalization)
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16
# OBS_IMAGE uses IDENTITY normalization, so no stats to check