From 1d86c9b7f2fc0bbf96db004299cb46161275327f Mon Sep 17 00:00:00 2001 From: Jade Choghari Date: Fri, 9 Jan 2026 23:08:37 +0100 Subject: [PATCH] feat(policies): add autoregressive VLAs with tokenization PiFast (#2734) --- docs/source/_toctree.yml | 2 + docs/source/pi0fast.mdx | 182 +++ pyproject.toml | 2 +- src/lerobot/policies/__init__.py | 2 + src/lerobot/policies/factory.py | 4 + src/lerobot/policies/pi0_fast/__init__.py | 21 + .../pi0_fast/configuration_pi0_fast.py | 161 ++ .../policies/pi0_fast/modeling_pi0_fast.py | 1353 +++++++++++++++++ .../policies/pi0_fast/processor_pi0_fast.py | 177 +++ .../policies/pi0_fast/train_fast_tokenizer.py | 539 +++++++ src/lerobot/processor/__init__.py | 3 +- src/lerobot/processor/tokenizer_processor.py | 266 +++- src/lerobot/utils/constants.py | 2 + src/lerobot/utils/import_utils.py | 1 + .../test_pi0_fast_original_vs_lerobot.py | 504 ++++++ 15 files changed, 3214 insertions(+), 5 deletions(-) create mode 100644 docs/source/pi0fast.mdx create mode 100644 src/lerobot/policies/pi0_fast/__init__.py create mode 100644 src/lerobot/policies/pi0_fast/configuration_pi0_fast.py create mode 100644 src/lerobot/policies/pi0_fast/modeling_pi0_fast.py create mode 100644 src/lerobot/policies/pi0_fast/processor_pi0_fast.py create mode 100644 src/lerobot/policies/pi0_fast/train_fast_tokenizer.py create mode 100644 tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 381e95dc4..2b8086cd7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -37,6 +37,8 @@ title: SmolVLA - local: pi0 title: π₀ (Pi0) + - local: pi0fast + title: π₀-FAST (Pi0Fast) - local: pi05 title: π₀.₅ (Pi05) - local: groot diff --git a/docs/source/pi0fast.mdx b/docs/source/pi0fast.mdx new file mode 100644 index 000000000..e64355765 --- /dev/null +++ b/docs/source/pi0fast.mdx @@ -0,0 +1,182 @@ +# π₀-FAST (Pi0-FAST) + +π₀-FAST is a **Vision-Language-Action model for general robot control** that uses autoregressive next-token prediction to model continuous robot actions. + +## Model Overview + +π₀-FAST combines the power of Vision-Language Models with a novel action tokenization approach called **FAST (Frequency-space Action Sequence Tokenization)**. This enables training autoregressive VLAs on highly dexterous tasks that are impossible with standard binning-based discretization, while training **up to 5x faster** than diffusion-based approaches like π₀. + +### Why FAST? + +Standard approaches for robot action tokenization use simple per-dimension, per-timestep binning schemes. While passable for simple behaviors, this rapidly breaks down for complex and dexterous skills that require precision and high-frequency control. + +FAST solves this by compressing action sequences using signal processing techniques, resulting in a dense sequence of action tokens that can be predicted autoregressively—just like language tokens. + +### How FAST Tokenization Works + +The FAST tokenizer compresses action sequences through the following steps: + +1. **Normalize**: Take a continuous action chunk of shape `(H, D)` where `H` is the horizon and `D` is the action dimension. Normalize using one of the supported normalization methods (Quantiles recommended to handle outliers). + +2. **Discrete Cosine Transform (DCT)**: Apply DCT (via scipy) to each action dimension separately. DCT is a compression algorithm commonly used in image and audio codecs (JPEG, MP3). + +3. 
**Quantization**: Round the coefficients and drop the insignificant ones for each action dimension, producing a sparse frequency matrix.

4. **Flatten**: Flatten the matrix into a 1D vector, with low-frequency components first.

5. **Byte Pair Encoding (BPE)**: Train a BPE tokenizer to compress the DCT coefficients into dense action tokens, typically achieving **10x compression** over prior tokenization approaches.

This approach can transform **any existing VLM** into a VLA by training it to predict these FAST tokens.

## Installation Requirements

1. Install LeRobot by following our [Installation Guide](./installation).
2. Install π₀-FAST dependencies by running:

   ```bash
   pip install -e ".[pi]"
   ```

   > [!NOTE]
   > For lerobot 0.4.0, if you want to install the `pi` extra, you will have to run: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
   >
   > This will be resolved in the next patch release.

## Training a Custom FAST Tokenizer

You have two options for the FAST tokenizer:

1. **Use the pre-trained tokenizer**: The `physical-intelligence/fast` tokenizer was trained on 1M+ real robot action sequences and works as a general-purpose tokenizer.

2. **Train your own tokenizer**: For maximum performance on a specific dataset, you can train a tokenizer on your own data.

### Training Your Own Tokenizer

```bash
python src/lerobot/policies/pi0_fast/train_fast_tokenizer.py \
    --repo_id "user/my-lerobot-dataset" \
    --action_horizon 10 \
    --encoded_dims "0:6" \
    --vocab_size 1024 \
    --scale 10.0 \
    --normalization_mode QUANTILES \
    --output_dir "./my_fast_tokenizer" \
    --push_to_hub \
    --hub_repo_id "username/my-action-tokenizer"
```

### Key Tokenizer Parameters

| Parameter              | Description                                                                        | Default      |
| ---------------------- | ---------------------------------------------------------------------------------- | ------------ |
| `--repo_id`            | LeRobot dataset repository ID                                                      | Required     |
| `--action_horizon`     | Number of future actions in each chunk                                             | `10`         |
| `--encoded_dims`       | Comma-separated dimension ranges to encode (e.g., `"0:6,7:23"`)                    | `"0:6,7:23"` |
| `--vocab_size`         | BPE vocabulary size                                                                | `1024`       |
| `--scale`              | DCT scaling factor for quantization                                                | `10.0`       |
| `--normalization_mode` | Normalization mode (`MEAN_STD`, `MIN_MAX`, `QUANTILES`, `QUANTILE10`, `IDENTITY`)  | `QUANTILES`  |
| `--sample_fraction`    | Fraction of chunks to sample per episode                                           | `0.1`        |

## Usage

To use π₀-FAST in LeRobot, set the policy type on the command line:

```bash
--policy.type=pi0_fast
```

## Training

To train π₀-FAST, use the LeRobot training script:

```bash
python src/lerobot/scripts/lerobot_train.py \
    --dataset.repo_id=your_dataset \
    --policy.type=pi0_fast \
    --output_dir=./outputs/pi0fast_training \
    --job_name=pi0fast_training \
    --policy.pretrained_path=lerobot/pi0_fast_base \
    --policy.dtype=bfloat16 \
    --policy.gradient_checkpointing=true \
    --policy.chunk_size=10 \
    --policy.n_action_steps=10 \
    --policy.max_action_tokens=256 \
    --steps=100000 \
    --batch_size=4 \
    --policy.device=cuda
```

### Key Training Parameters

| Parameter                               | Description                                         | Default                      |
| --------------------------------------- | --------------------------------------------------- | ---------------------------- |
| `--policy.gradient_checkpointing=true`  | Significantly reduces memory usage during training  | `false`                      |
| `--policy.dtype=bfloat16`               | Use mixed-precision training for efficiency         | `float32`                    |
| `--policy.chunk_size`                   | Number of action steps to predict (action horizon)  | `50`                         |
| `--policy.n_action_steps`               | Number of action steps to execute                   | `50`                         |
| `--policy.max_action_tokens`            | Maximum number of FAST tokens per action chunk      | `256`                        |
| `--policy.action_tokenizer_name`        | FAST tokenizer to use                               | `physical-intelligence/fast` |
| `--policy.compile_model=true`           | Enable torch.compile for faster training            | `false`                      |

## Inference

### KV-Caching for Fast Inference

π₀-FAST supports **KV-caching**, a widely used optimization in LLM inference. It caches the key-value pairs from the attention mechanism, avoiding redundant computation during autoregressive decoding.

```bash
# KV-caching is enabled by default
--policy.use_kv_cache=true
```

### Inference Example

```python
from lerobot.policies.pi0_fast import PI0FastPolicy

# Load the policy
policy = PI0FastPolicy.from_pretrained("your-model-path")

# During inference
actions = policy.predict_action_chunk(batch)
```

## Model Architecture

π₀-FAST uses a PaliGemma-based architecture:

- **Vision Encoder**: SigLIP vision tower for image understanding
- **Language Model**: Gemma 2B for processing language instructions and predicting action tokens

The model takes images, text instructions, and robot state as input, and outputs discrete FAST tokens that are decoded back to continuous actions.

## Configuration Options

| Parameter            | Description                                     | Default    |
| -------------------- | ----------------------------------------------- | ---------- |
| `paligemma_variant`  | VLM backbone variant (`gemma_300m`, `gemma_2b`) | `gemma_2b` |
| `max_state_dim`      | Maximum state vector dimension (padded)         | `32`       |
| `max_action_dim`     | Maximum action vector dimension (padded)        | `32`       |
| `temperature`        | Sampling temperature (`0.0` for greedy)         | `0.0`      |
| `max_decoding_steps` | Maximum decoding steps                          | `256`      |
| `use_kv_cache`       | Enable KV caching for faster inference          | `true`     |

## Comparison with π₀

| Feature               | π₀                        | π₀-FAST                      |
| --------------------- | ------------------------- | ---------------------------- |
| Action Representation | Flow Matching (Diffusion) | Autoregressive Tokens (FAST) |
| Training Speed        | 1x                        | **5x faster**                |
| Dexterity             | High                      | High                         |
| Inference Method      | Iterative Denoising       | Autoregressive Decoding      |
| KV-Caching            | N/A                       | Supported                    |

## License

This model is released under the **Apache 2.0 License**, consistent with the original [OpenPI repository](https://github.com/Physical-Intelligence/openpi).

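
## Appendix: DCT Round-Trip Sketch

To make steps 2-4 above concrete, here is a minimal NumPy/SciPy sketch of the DCT round-trip at the core of FAST. It is illustrative only: the real tokenizer also normalizes the chunk and BPE-compresses the flattened coefficients, and the `scale`, `horizon`, and `action_dim` values are assumptions rather than the tokenizer's actual settings. The decode line mirrors `decode_actions_with_fast`, which calls `idct(..., axis=0, norm="ortho")` after BPE decoding.

```python
import numpy as np
from scipy.fftpack import dct, idct

horizon, action_dim = 10, 6
# An already-normalized action chunk of shape (H, D)
chunk = np.random.randn(horizon, action_dim).astype(np.float32)

scale = 10.0  # quantization scale (cf. the --scale training argument)
coeffs = dct(chunk, axis=0, norm="ortho")  # step 2: per-dimension DCT
quantized = np.round(coeffs * scale)       # step 3: small coefficients round to zero
flat = quantized.flatten()                 # step 4: 1D vector, low frequencies first

# Decoding inverts the pipeline: rescale, then apply the inverse DCT per dimension
recovered = idct(flat.reshape(horizon, action_dim) / scale, axis=0, norm="ortho")
print(np.max(np.abs(recovered - chunk)))   # small, bounded quantization error
```

In the full pipeline, BPE maps `flat` to dense token IDs; at inference the policy first inverts BPE, then applies the `idct` step shown here.
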
+ +## References + +- [FAST: Efficient Robot Action Tokenization](https://www.physicalintelligence.company/research/fast) - Physical Intelligence Blog +- [OpenPI Repository](https://github.com/Physical-Intelligence/openpi) - Original implementation +- [FAST Tokenizer on Hugging Face](https://huggingface.co/physical-intelligence/fast) - Pre-trained tokenizer diff --git a/pyproject.toml b/pyproject.toml index f25645319..75738d2de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,7 +127,7 @@ wallx = [ "torchdiffeq==0.2.5", "qwen_vl_utils==0.0.11" ] -pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi"] +pi = ["transformers @ git+https://github.com/huggingface/transformers.git@fix/lerobot_openpi", "scipy>=1.10.1,<1.15"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] groot = [ "lerobot[transformers-dep]", diff --git a/src/lerobot/policies/__init__.py b/src/lerobot/policies/__init__.py index 99275e787..c7951f028 100644 --- a/src/lerobot/policies/__init__.py +++ b/src/lerobot/policies/__init__.py @@ -16,6 +16,7 @@ from .act.configuration_act import ACTConfig as ACTConfig from .diffusion.configuration_diffusion import DiffusionConfig as DiffusionConfig from .groot.configuration_groot import GrootConfig as GrootConfig from .pi0.configuration_pi0 import PI0Config as PI0Config +from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig from .pi05.configuration_pi05 import PI05Config as PI05Config from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig from .smolvla.processor_smolvla import SmolVLANewLineProcessor @@ -29,6 +30,7 @@ __all__ = [ "DiffusionConfig", "PI0Config", "PI05Config", + "PI0FastConfig", "SmolVLAConfig", "SARMConfig", "TDMPCConfig", diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 8c414f235..fff08ad37 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -91,6 +91,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]: from lerobot.policies.pi0.modeling_pi0 import PI0Policy return PI0Policy + elif name == "pi0_fast": + from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy + + return PI0FastPolicy elif name == "pi05": from lerobot.policies.pi05.modeling_pi05 import PI05Policy diff --git a/src/lerobot/policies/pi0_fast/__init__.py b/src/lerobot/policies/pi0_fast/__init__.py new file mode 100644 index 000000000..a0277da0f --- /dev/null +++ b/src/lerobot/policies/pi0_fast/__init__.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .configuration_pi0_fast import PI0FastConfig +from .modeling_pi0_fast import PI0FastPolicy +from .processor_pi0_fast import make_pi0_fast_pre_post_processors + +__all__ = ["PI0FastConfig", "PI0FastPolicy", "make_pi0_fast_pre_post_processors"] diff --git a/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py new file mode 100644 index 000000000..42aa4a132 --- /dev/null +++ b/src/lerobot/policies/pi0_fast/configuration_pi0_fast.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.optim.optimizers import AdamWConfig +from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig +from lerobot.policies.rtc.configuration_rtc import RTCConfig + +DEFAULT_IMAGE_SIZE = 224 + + +@PreTrainedConfig.register_subclass("pi0_fast") +@dataclass +class PI0FastConfig(PreTrainedConfig): + paligemma_variant: str = "gemma_2b" + action_expert_variant: str = "gemma_300m" + dtype: str = "float32" # Options: "bfloat16", "float32" + + chunk_size: int = 50 # Number of action steps to predict, in openpi called "action_horizon" + n_action_steps: int = 50 # Number of action steps to execute + + # Shorter state and action vectors will be padded to these dimensions + max_state_dim: int = 32 + max_action_dim: int = 32 + max_action_tokens: int = 256 + + # Real-Time Chunking (RTC) configuration + rtc_config: RTCConfig | None = None + + image_resolution: tuple[int, int] = ( + DEFAULT_IMAGE_SIZE, + DEFAULT_IMAGE_SIZE, + ) # see openpi `preprocessing_pytorch.py` + + # Add empty images. Used to add empty cameras when no image features are present. 
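    # Missing cameras are filled in `_preprocess_images` with fully padded (-1) images and zero attention masks.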
+ empty_cameras: int = 0 + + tokenizer_max_length: int = 200 # see openpi `__post_init__` + text_tokenizer_name: str = "google/paligemma-3b-pt-224" + action_tokenizer_name: str = "physical-intelligence/fast" + temperature: float = 0.0 + max_decoding_steps: int = 256 + fast_skip_tokens: int = 128 + + # Whether to validate that decoded action tokens start with "Action: " prefix + validate_action_token_prefix: bool = True + + # Whether to use KV cache for faster autoregressive decoding + use_kv_cache: bool = True + + normalization_mapping: dict[str, NormalizationMode] = field( + default_factory=lambda: { + "VISUAL": NormalizationMode.IDENTITY, + "STATE": NormalizationMode.MEAN_STD, # Pi0Fast uses quantiles for state + "ACTION": NormalizationMode.MEAN_STD, # Pi0Fast uses quantiles for action + } + ) + + # Training settings + gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization + compile_model: bool = False # Whether to use torch.compile for model optimization + compile_mode: str = "max-autotune" # Torch compile mode + device: str | None = None # Device to use for the model (None = auto-detect) + + # Optimizer settings: see openpi `AdamW` + optimizer_lr: float = 2.5e-5 # see openpi `CosineDecaySchedule: peak_lr` + optimizer_betas: tuple[float, float] = (0.9, 0.95) + optimizer_eps: float = 1e-8 + optimizer_weight_decay: float = 0.01 + optimizer_grad_clip_norm: float = 1.0 + + # Scheduler settings: see openpi `CosineDecaySchedule` + # Note: These will auto-scale if --steps < scheduler_decay_steps + # For example, --steps=3000 will scale warmup to 100 and decay to 3000 + scheduler_warmup_steps: int = 1_000 + scheduler_decay_steps: int = 30_000 + scheduler_decay_lr: float = 2.5e-6 + + def __post_init__(self): + super().__post_init__() + + # Validate configuration + if self.n_action_steps > self.chunk_size: + raise ValueError( + f"n_action_steps ({self.n_action_steps}) cannot be greater than chunk_size ({self.chunk_size})" + ) + + if self.paligemma_variant not in ["gemma_300m", "gemma_2b"]: + raise ValueError(f"Invalid paligemma_variant: {self.paligemma_variant}") + + if self.dtype not in ["bfloat16", "float32"]: + raise ValueError(f"Invalid dtype: {self.dtype}") + + def validate_features(self) -> None: + """Validate and set up input/output features.""" + for i in range(self.empty_cameras): + key = f"observation.images.empty_camera_{i}" + empty_camera = PolicyFeature( + type=FeatureType.VISUAL, + shape=(3, *self.image_resolution), # Use configured image resolution + ) + self.input_features[key] = empty_camera + + if "observation.state" not in self.input_features: + state_feature = PolicyFeature( + type=FeatureType.STATE, + shape=(self.max_state_dim,), # Padded to max_state_dim + ) + self.input_features["observation.state"] = state_feature + + if "action" not in self.output_features: + action_feature = PolicyFeature( + type=FeatureType.ACTION, + shape=(self.max_action_dim,), # Padded to max_action_dim + ) + self.output_features["action"] = action_feature + + def get_optimizer_preset(self) -> AdamWConfig: + return AdamWConfig( + lr=self.optimizer_lr, + betas=self.optimizer_betas, + eps=self.optimizer_eps, + weight_decay=self.optimizer_weight_decay, + grad_clip_norm=self.optimizer_grad_clip_norm, + ) + + def get_scheduler_preset(self): + return CosineDecayWithWarmupSchedulerConfig( + peak_lr=self.optimizer_lr, + decay_lr=self.scheduler_decay_lr, + num_warmup_steps=self.scheduler_warmup_steps, + num_decay_steps=self.scheduler_decay_steps, + ) + + @property + def 
observation_delta_indices(self) -> None: + return None + + @property + def action_delta_indices(self) -> list: + return list(range(self.chunk_size)) + + @property + def reward_delta_indices(self) -> None: + return None diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py new file mode 100644 index 000000000..b4bc7ba22 --- /dev/null +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -0,0 +1,1353 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import builtins +import logging +import math +from collections import deque +from pathlib import Path +from typing import TYPE_CHECKING, Literal, TypedDict + +import numpy as np +import torch +import torch.nn.functional as F # noqa: N812 +from torch import Tensor, nn +from typing_extensions import Unpack + +from lerobot.utils.import_utils import _scipy_available, _transformers_available + +# Conditional import for type checking and lazy loading +if TYPE_CHECKING or _scipy_available: + from scipy.fftpack import idct +else: + idct = None + +if TYPE_CHECKING or _transformers_available: + from transformers import AutoTokenizer + from transformers.models.auto import CONFIG_MAPPING + from transformers.models.paligemma.modeling_paligemma import PaliGemmaForConditionalGeneration +else: + CONFIG_MAPPING = None + PaliGemmaForConditionalGeneration = None + AutoTokenizer = None + +from lerobot.configs.policies import PreTrainedConfig +from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig +from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.rtc.modeling_rtc import RTCProcessor +from lerobot.utils.constants import ( + ACTION, + ACTION_TOKEN_MASK, + ACTION_TOKENS, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OPENPI_ATTENTION_MASK_VALUE, +) + + +class ActionSelectKwargs(TypedDict, total=False): + temperature: float | None + + +def pad_vector(vector, new_dim): + """Pad the last dimension of a vector to new_dim with zeros. + + Can be (batch_size x sequence_length x features_dimension) + or (batch_size x features_dimension) + """ + if vector.shape[-1] >= new_dim: + return vector + return F.pad(vector, (0, new_dim - vector.shape[-1])) + + +def resize_with_pad_torch( # see openpi `resize_with_pad_torch` (exact copy) + images: torch.Tensor, + height: int, + width: int, + mode: str = "bilinear", +) -> torch.Tensor: + """PyTorch version of resize_with_pad. Resizes an image to a target height and width without distortion + by padding with black. If the image is float32, it must be in the range [-1, 1]. + + Args: + images: Tensor of shape [*b, h, w, c] or [*b, c, h, w] + height: Target height + width: Target width + mode: Interpolation mode ('bilinear', 'nearest', etc.) 
+ + Returns: + Resized and padded tensor with same shape format as input + """ + # Check if input is in channels-last format [*b, h, w, c] or channels-first [*b, c, h, w] + if images.shape[-1] <= 4: # Assume channels-last format + channels_last = True + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + images = images.permute(0, 3, 1, 2) # [b, h, w, c] -> [b, c, h, w] + else: + channels_last = False + if images.dim() == 3: + images = images.unsqueeze(0) # Add batch dimension + + batch_size, channels, cur_height, cur_width = images.shape + + # Calculate resize ratio + ratio = max(cur_width / width, cur_height / height) + resized_height = int(cur_height / ratio) + resized_width = int(cur_width / ratio) + + # Resize + resized_images = F.interpolate( + images, + size=(resized_height, resized_width), + mode=mode, + align_corners=False if mode == "bilinear" else None, + ) + + # Handle dtype-specific clipping + if images.dtype == torch.uint8: + resized_images = torch.round(resized_images).clamp(0, 255).to(torch.uint8) + elif images.dtype == torch.float32: + resized_images = resized_images.clamp(-1.0, 1.0) + else: + raise ValueError(f"Unsupported image dtype: {images.dtype}") + + # Calculate padding + pad_h0, remainder_h = divmod(height - resized_height, 2) + pad_h1 = pad_h0 + remainder_h + pad_w0, remainder_w = divmod(width - resized_width, 2) + pad_w1 = pad_w0 + remainder_w + + # Pad + constant_value = 0 if images.dtype == torch.uint8 else -1.0 + padded_images = F.pad( + resized_images, + (pad_w0, pad_w1, pad_h0, pad_h1), # left, right, top, bottom + mode="constant", + value=constant_value, + ) + + # Convert back to original format if needed + if channels_last: + padded_images = padded_images.permute(0, 2, 3, 1) # [b, c, h, w] -> [b, h, w, c] + + return padded_images + + +class GemmaConfig: # see openpi `gemma.py: Config` + """Configuration for Gemma model variants.""" + + def __init__(self, width, depth, mlp_dim, num_heads, num_kv_heads, head_dim): + self.width = width + self.depth = depth + self.mlp_dim = mlp_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + +def get_gemma_config(variant: str) -> GemmaConfig: # see openpi `gemma.py: get_config` + """Returns config for specified gemma variant.""" + if variant == "gemma_300m": + return GemmaConfig( + width=1024, + depth=18, + mlp_dim=4096, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + elif variant == "gemma_2b": + return GemmaConfig( + width=2048, + depth=18, + mlp_dim=16_384, + num_heads=8, + num_kv_heads=1, + head_dim=256, + ) + else: + raise ValueError(f"Unknown variant: {variant}") + + +class PI0FastPaliGemma(nn.Module): + """PaliGemma model for PI0Fast""" + + def __init__( + self, + vlm_config, + use_adarms=None, + precision: Literal["bfloat16", "float32"] = "bfloat16", + ): + if use_adarms is None: + use_adarms = [False, False] + super().__init__() + + vlm_config_hf = CONFIG_MAPPING["paligemma"]() + vlm_config_hf._vocab_size = 257152 # noqa: SLF001 + vlm_config_hf.image_token_index = 257152 + vlm_config_hf.text_config.hidden_size = vlm_config.width + vlm_config_hf.text_config.intermediate_size = vlm_config.mlp_dim + vlm_config_hf.text_config.num_attention_heads = vlm_config.num_heads + vlm_config_hf.text_config.head_dim = vlm_config.head_dim + vlm_config_hf.text_config.num_hidden_layers = vlm_config.depth + vlm_config_hf.text_config.num_key_value_heads = vlm_config.num_kv_heads + vlm_config_hf.text_config.hidden_activation = "gelu_pytorch_tanh" + 
vlm_config_hf.text_config.torch_dtype = "float32" + vlm_config_hf.text_config.vocab_size = 257152 + vlm_config_hf.text_config.use_adarms = use_adarms[0] + vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None + vlm_config_hf.vision_config.intermediate_size = 4304 + vlm_config_hf.vision_config.projection_dim = 2048 + vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast" + vlm_config_hf.vision_config.torch_dtype = "float32" + + self.paligemma = PaliGemmaForConditionalGeneration(config=vlm_config_hf) + + self.to_bfloat16_for_selected_params(precision) + + def to_bfloat16_for_selected_params(self, precision: Literal["bfloat16", "float32"] = "bfloat16"): + if precision == "bfloat16": + self.to(dtype=torch.bfloat16) + elif precision == "float32": + self.to(dtype=torch.float32) + return + else: + raise ValueError(f"Invalid precision: {precision}") + + params_to_keep_float32 = [ + "vision_tower.vision_model.embeddings.patch_embedding.weight", + "vision_tower.vision_model.embeddings.patch_embedding.bias", + "vision_tower.vision_model.embeddings.position_embedding.weight", + "input_layernorm", + "post_attention_layernorm", + "model.norm", + ] + + for name, param in self.named_parameters(): + if any(selector in name for selector in params_to_keep_float32): + param.data = param.data.to(dtype=torch.float32) + + def embed_image(self, image: torch.Tensor): + return self.paligemma.model.get_image_features(image) + + def embed_language_tokens(self, tokens: torch.Tensor): + return self.paligemma.language_model.embed_tokens(tokens) + + def forward( + self, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: list[torch.FloatTensor] | None = None, + inputs_embeds: list[torch.FloatTensor] | None = None, + use_cache: bool | None = None, + adarms_cond: list[torch.Tensor] | None = None, + ): + if adarms_cond is None: + adarms_cond = [None, None] + if inputs_embeds[1] is None: + prefix_output = self.paligemma.language_model.forward( + inputs_embeds=inputs_embeds[0], + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + adarms_cond=adarms_cond[0] if adarms_cond is not None else None, + ) + prefix_past_key_values = prefix_output.past_key_values + # prefix_output to be used for the language head + # shape: [batch_size, seq_len, hidden_size] with hidden_size = 2048 + prefix_output = prefix_output.last_hidden_state + suffix_output = None + return [prefix_output, suffix_output], prefix_past_key_values + + +class PI0FastPytorch(nn.Module): # see openpi `PI0Pytorch` + """Core PI0Fast PyTorch model.""" + + def __init__( + self, + config: PI0FastConfig, + rtc_processor: RTCProcessor | None = None, + paligemma_tokenizer: "AutoTokenizer | None" = None, + ): + super().__init__() + self.config = config + self.rtc_processor = rtc_processor + self._paligemma_tokenizer = paligemma_tokenizer + + paligemma_config = get_gemma_config(config.paligemma_variant) + + self.paligemma_with_expert = PI0FastPaliGemma( + paligemma_config, + use_adarms=[False, True], + precision=config.dtype, + ) + + # Initialize gradient checkpointing flag + self.gradient_checkpointing_enabled = False + + # Compile model if requested + if config.compile_model: + torch.set_float32_matmul_precision("high") + self.sample_actions_fast = torch.compile(self.sample_actions_fast, mode=config.compile_mode) + self.forward = torch.compile(self.forward, mode=config.compile_mode) + + msg = """An incorrect 
transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues""" + + try: + from transformers.models.siglip import check + + if not check.check_whether_transformers_replace_is_installed_correctly(): + raise ValueError(msg) + except ImportError: + raise ValueError(msg) from None + + def gradient_checkpointing_enable(self): + """Enable gradient checkpointing for memory optimization.""" + self.gradient_checkpointing_enabled = True + # Call the proper gradient_checkpointing_enable() method with use_reentrant=False for better memory efficiency + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) + logging.info("Enabled gradient checkpointing for PI0FastPytorch model") + + def gradient_checkpointing_disable(self): + """Disable gradient checkpointing.""" + self.gradient_checkpointing_enabled = False + # Call the proper gradient_checkpointing_disable() method + self.paligemma_with_expert.paligemma.language_model.gradient_checkpointing_disable() + self.paligemma_with_expert.paligemma.vision_tower.gradient_checkpointing_disable() + logging.info("Disabled gradient checkpointing for PI0FastPytorch model") + + def _apply_checkpoint(self, func, *args, **kwargs): + """Helper method to apply gradient checkpointing if enabled.""" + if self.gradient_checkpointing_enabled and self.training: + return torch.utils.checkpoint.checkpoint( + func, *args, use_reentrant=False, preserve_rng_state=False, **kwargs + ) + return func(*args, **kwargs) + + def _prepare_attention_masks_4d(self, att_2d_masks, dtype=None): + """Helper method to prepare 4D attention masks for transformer.""" + att_2d_masks_4d = att_2d_masks[:, None, :, :] + result = torch.where(att_2d_masks_4d, 0.0, OPENPI_ATTENTION_MASK_VALUE) + if dtype is not None: + result = result.to(dtype=dtype) + return result + + def embed_prefix_fast( + self, + images, + img_masks, + tokens, + masks, + fast_action_tokens=None, + fast_action_masks=None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]: + """Embed images, language tokens, and FAST action tokens. 
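        Shared by the training forward pass and both autoregressive decoding paths.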
+ + Attention pattern: + - Images + Language: bidirectional among themselves + - FAST: attend to images + language, causal among themselves + + Args: + images: List of image tensors + img_masks: List of image masks + tokens: Language instruction tokens + masks: Attention masks for tokens + fast_action_tokens: FAST action tokens (discrete token IDs) + fast_action_masks: Padding masks for FAST action tokens + + Returns: + embs: Concatenated embeddings [images, tokens, fast_action_tokens] + pad_masks: Padding masks + att_masks: 2D attention mask + total_T_images: Total number of image tokens + num_fast_embs: Number of FAST action token embeddings + """ + embs = [] + pad_masks = [] + att_mask_segments = [] + total_t_images = 0 + num_fast_embs = 0 + + # Process images + for img, img_mask in zip(images, img_masks, strict=True): + + def image_embed_func(img): + return self.paligemma_with_expert.embed_image(img) + + img_emb = self._apply_checkpoint(image_embed_func, img) + bsize, num_img_embs = img_emb.shape[:2] + + embs.append(img_emb) + pad_masks.append(img_mask[:, None].expand(bsize, num_img_embs)) + att_mask_segments.append(("image", num_img_embs)) + total_t_images += num_img_embs + + # Process language instruction tokens + def lang_embed_func(tokens): + lang_emb = self.paligemma_with_expert.embed_language_tokens(tokens) + lang_emb_dim = lang_emb.shape[-1] + return lang_emb * math.sqrt(lang_emb_dim) + + lang_emb = self._apply_checkpoint(lang_embed_func, tokens) + embs.append(lang_emb) + pad_masks.append(masks) + + num_lang_embs = lang_emb.shape[1] + att_mask_segments.append(("language", num_lang_embs)) + + # Process FAST action tokens (discrete token IDs) + if fast_action_tokens is not None: + + def fast_action_embed_func(fast_action_tokens): + fast_emb = self.paligemma_with_expert.embed_language_tokens(fast_action_tokens) + fast_emb_dim = fast_emb.shape[-1] + return fast_emb * math.sqrt(fast_emb_dim) + + fast_action_emb = self._apply_checkpoint(fast_action_embed_func, fast_action_tokens) + embs.append(fast_action_emb) + + num_fast_embs = fast_action_tokens.shape[1] + pad_masks.append(fast_action_masks) + att_mask_segments.append(("fast", num_fast_embs)) + + embs = torch.cat(embs, dim=1) + pad_masks = torch.cat(pad_masks, dim=1) + + # Create custom 2D attention mask: + # - Images + Language: bidirectional among themselves + # - FAST: attend to images + language, causal among themselves + att_masks = self._create_custom_attention_mask_fast(att_mask_segments, pad_masks, bsize) + + return embs, pad_masks, att_masks, total_t_images, num_fast_embs + + def _create_custom_attention_mask_fast(self, att_mask_segments, pad_masks, bsize): + """Create custom 2D attention mask. 
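        Returns a boolean (B, total_len, total_len) mask with padding already applied.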
+ + Attention rules: + - Images + Language: bidirectional among themselves + - FAST: attend to images + language, causal among themselves + """ + total_len = sum(length for _, length in att_mask_segments) + device = pad_masks.device + + att_2d_masks = torch.zeros(bsize, total_len, total_len, dtype=torch.bool, device=device) + + positions = [] + current_pos = 0 + for seg_type, seg_len in att_mask_segments: + positions.append((seg_type, current_pos, current_pos + seg_len)) + current_pos += seg_len + + for _i, (query_type, query_start, query_end) in enumerate(positions): + for _j, (key_type, key_start, key_end) in enumerate(positions): + # Images and Language can attend to each other bidirectionally + if ( + query_type in ["image", "language"] + and key_type in ["image", "language"] + or query_type == "fast" + and key_type in ["image", "language"] + ): + att_2d_masks[:, query_start:query_end, key_start:key_end] = True + + # FAST tokens attend causally to themselves + elif query_type == "fast" and key_type == "fast": + fast_len = query_end - query_start + causal_mask = torch.tril(torch.ones(fast_len, fast_len, dtype=torch.bool, device=device)) + att_2d_masks[:, query_start:query_end, key_start:key_end] = causal_mask[None, :, :] + + # Apply padding masks + pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None] + att_2d_masks = att_2d_masks & pad_2d_masks + + return att_2d_masks + + def forward( + self, + images, + img_masks, + tokens, + masks, + fast_action_tokens, + fast_action_masks, + ) -> dict: + """Forward pass for PI0Fast. + + This implements the Pi0FAST training objective: predict next action token + using cross-entropy loss. + + Args: + images: List of image tensors + img_masks: List of image masks + tokens: Language instruction tokens + masks: Attention masks for tokens + fast_action_tokens: Discrete action token IDs [B, max_action_tokens] + fast_action_masks: Padding masks for fast action tokens [B, max_action_tokens] + + Returns: + Dictionary with 'fast_loss' and 'loss' keys + """ + if fast_action_tokens is None or fast_action_masks is None: + raise ValueError("fast_action_tokens and fast_action_masks are required for FAST-only mode") + + # Embed prefix with FAST tokens + prefix_embs, prefix_pad_masks, prefix_att_masks, total_t_images, num_fast_embs = ( + self.embed_prefix_fast( + images, + img_masks, + tokens, + masks, + fast_action_tokens=fast_action_tokens, + fast_action_masks=fast_action_masks, + ) + ) + + # Convert embeddings to bfloat16 if needed + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + # for next-token prediction, input tokens [0:T-1] to predict tokens [1:T] + input_embs = prefix_embs + input_pad_masks = prefix_pad_masks + input_att_masks = prefix_att_masks + + position_ids = torch.cumsum(input_pad_masks, dim=1) - 1 + att_2d_4d = self._prepare_attention_masks_4d(input_att_masks, dtype=input_embs.dtype) + + # forward pass through paligemma (language model) + (prefix_out, _), _ = self.paligemma_with_expert.forward( + attention_mask=att_2d_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[input_embs, None], # No suffix/action expert + use_cache=False, + adarms_cond=[None, None], + ) + + # Get logits for FAST action tokens using the FAST LM head + # only compute logits for the positions that predict FAST tokens + lm_head = self.paligemma_with_expert.paligemma.lm_head + + # Targets are the FAST action tokens + 
fast_targets = fast_action_tokens # (B, num_fast_embs) + + # extract logits for FAST token prediction + fast_hidden = prefix_out[:, -fast_targets.shape[1] :, :] + fast_logits_for_pred = lm_head(fast_hidden) # (B, num_fast_embs, gemma_vocab_size) + + # Shift left for next-step prediction and shift target + # logits[:, i] predicts targets[:, i+1] + fast_logits_for_pred = fast_logits_for_pred[:, :-1, :] # shift logits left + fast_targets = fast_targets[:, 1:] # shift targets right + fast_action_masks = fast_action_masks[:, 1:] # shift masks to match targets + + # compute cross-entropy loss + loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + fast_logits_flat = fast_logits_for_pred.reshape(-1, fast_logits_for_pred.size(-1)) + fast_targets_flat = fast_targets.reshape(-1) + + fast_loss_per_token = loss_fct(fast_logits_flat, fast_targets_flat) + fast_loss_per_token = fast_loss_per_token.reshape(fast_targets.shape) + + # apply mask and compute mean loss + masked_fast_loss = fast_loss_per_token * fast_action_masks.float() + fast_loss = masked_fast_loss.sum() / fast_action_masks.sum().clamp(min=1) + + return { + "ce_loss": fast_loss, + "loss": fast_loss, + } + + @torch.no_grad() + def sample_actions_fast( + self, + images, + img_masks, + tokens, + masks, + max_decoding_steps=None, + temperature=0.0, + ) -> torch.Tensor: + """ + Inefficient but safe autoregressive decoding for FAST tokens. + Matches the pattern of _generate_subtask_tokens. + TODO: jadechoghari, should we move this logic to PI0FastPolicy class? + """ + if max_decoding_steps is None: + max_decoding_steps = self.config.max_action_tokens + + bsize = tokens.shape[0] + device = tokens.device + lm_head = self.paligemma_with_expert.paligemma.lm_head + + # add bos token after tokens + bos_token = torch.full( + (bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device + ) + tokens = torch.cat([tokens, bos_token], dim=1) + masks = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1) + + # 1. Initial Embedding (matches training prefix) + # prefix_embs will include [Images, Language Prompt, BOS] + prefix_embs, prefix_pad_masks, prefix_att_masks, total_t_images, _ = self.embed_prefix_fast( + images, img_masks, tokens, masks, fast_action_tokens=None, fast_action_masks=None + ) + + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device) + + # 2. 
Decoding Loop (each step re-computes full sequence) + for t in range(max_decoding_steps): + # always re-calculate position IDs from the current pad mask + position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype) + + # full forward pass (no kv cache) + (prefix_out, _), _ = self.paligemma_with_expert.forward( + attention_mask=att_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=False, + adarms_cond=[None, None], + ) + + # predict next token from the very last sequence position + last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, vocab_size) + + if temperature > 0: + probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True) + + generated_action_tokens[:, t] = next_token.squeeze(-1) + + # 3. Update sequence for next iteration (unless it's the last step) + if t < max_decoding_steps - 1: + # embed the newly generated token + next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) + next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) + if prefix_embs.dtype == torch.bfloat16: + next_token_emb = next_token_emb.to(dtype=torch.bfloat16) + + # append to embeddings + prefix_embs = torch.cat([prefix_embs, next_token_emb], dim=1) + + # update padding mask (new token is always valid/1) + prefix_pad_masks = torch.cat( + [prefix_pad_masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1 + ) + + # update 2d attention mask: grow the matrix + old_len = prefix_att_masks.shape[1] + new_len = old_len + 1 + new_att_masks = torch.zeros((bsize, new_len, new_len), dtype=torch.bool, device=device) + new_att_masks[:, :old_len, :old_len] = prefix_att_masks + # new token attends to all non-padding tokens in the updated sequence + new_att_masks[:, -1, :] = prefix_pad_masks + prefix_att_masks = new_att_masks + return generated_action_tokens + + @torch.no_grad() + def sample_actions_fast_kv_cache( + self, + images, + img_masks, + tokens, + masks, + max_decoding_steps=None, + temperature=0.0, + ) -> torch.Tensor: + """ + Optimized autoregressive decoding for FAST tokens using KV Caching. + """ + if max_decoding_steps is None: + max_decoding_steps = self.config.max_action_tokens + + bsize = tokens.shape[0] + device = tokens.device + lm_head = self.paligemma_with_expert.paligemma.lm_head + + # --- 1. PREFILL PHASE --- + # Process Images + Text Prompt + BOS token once to populate the KV cache. 
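        # Each later decode step then feeds a single token and reuses the cached keys/values.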
+ + # Add BOS token to the prompt + bos_token = torch.full( + (bsize, 1), self._paligemma_tokenizer.bos_token_id, dtype=torch.long, device=device + ) + tokens_in = torch.cat([tokens, bos_token], dim=1) + masks_in = torch.cat([masks, torch.ones((bsize, 1), dtype=torch.bool, device=device)], dim=1) + + # Embed prefix [Images, Language, BOS] + # fast_action_tokens=None means we are just embedding the condition (images+text) + prefix_embs, prefix_pad_masks, prefix_att_masks, total_t_images, _ = self.embed_prefix_fast( + images, img_masks, tokens_in, masks_in, fast_action_tokens=None, fast_action_masks=None + ) + + # Ensure correct precision (bfloat16/float32) + if ( + self.paligemma_with_expert.paligemma.language_model.layers[0].self_attn.q_proj.weight.dtype + == torch.bfloat16 + ): + prefix_embs = prefix_embs.to(dtype=torch.bfloat16) + + # Create position IDs (cumsum of mask - 1) + position_ids = torch.cumsum(prefix_pad_masks, dim=1) - 1 + + # Create 4D mask for the prefix + att_4d = self._prepare_attention_masks_4d(prefix_att_masks, dtype=prefix_embs.dtype) + + # Forward pass (Prefill) with use_cache=True + # We only pass [prefix_embs, None] because we aren't using the suffix (expert) model yet + (prefix_out, _), past_key_values = self.paligemma_with_expert.forward( + attention_mask=att_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=True, # Enable caching + adarms_cond=[None, None], + ) + + # Sample the first action token from the last logit of the prefix + last_logits = lm_head(prefix_out[:, -1:, :]) # (B, 1, V) + if temperature > 0: + probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True) + + # Initialize storage for generated tokens + generated_action_tokens = torch.zeros((bsize, max_decoding_steps), dtype=torch.long, device=device) + generated_action_tokens[:, 0] = next_token.squeeze(-1) + + # Track valid tokens mask (0 for pad, 1 for valid) + # We need this to tell the new token what it can attend to (images + text + past actions) + current_pad_mask = prefix_pad_masks + + # --- 2. DECODING PHASE --- + # Generate remaining tokens one by one using the cache. + + for t in range(1, max_decoding_steps): + # Embed the single previous token + # We use embed_language_tokens directly to avoid overhead of full prefix embedding + next_token_emb = self.paligemma_with_expert.embed_language_tokens(next_token) + next_token_emb = next_token_emb * math.sqrt(next_token_emb.shape[-1]) + if prefix_embs.dtype == torch.bfloat16: + next_token_emb = next_token_emb.to(dtype=torch.bfloat16) + + # Update Pad Mask: append 1s for the new valid token + new_column = torch.ones((bsize, 1), dtype=torch.bool, device=device) + current_pad_mask = torch.cat([current_pad_mask, new_column], dim=1) + + # Update Position IDs for the single new token + current_position_ids = (torch.sum(current_pad_mask, dim=1, keepdim=True) - 1).long() + + # Create Attention Mask for the single new step + # The new token attends to all valid tokens in history (captured by current_pad_mask). + # Shape becomes (B, 1, 1, Total_Len) which works with HF's cache logic. 
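            # The (B, 1, L) pad mask gains its head dimension inside _prepare_attention_masks_4d, giving (B, 1, 1, L).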
+ step_att_mask = self._prepare_attention_masks_4d( + current_pad_mask.unsqueeze(1), dtype=next_token_emb.dtype + ) + + # Forward pass (Decoding step) + # input_embeds is just the new token (B, 1, D) + (step_out, _), past_key_values = self.paligemma_with_expert.forward( + attention_mask=step_att_mask, + position_ids=current_position_ids, + past_key_values=past_key_values, # Pass updated cache + inputs_embeds=[next_token_emb, None], + use_cache=True, + adarms_cond=[None, None], + ) + + # Sample next token + last_logits = lm_head(step_out[:, -1:, :]) + if temperature > 0: + probs = torch.softmax(last_logits[:, -1] / temperature, dim=-1) + next_token = torch.multinomial(probs, num_samples=1) + else: + next_token = torch.argmax(last_logits[:, -1], dim=-1, keepdim=True) + + generated_action_tokens[:, t] = next_token.squeeze(-1) + + return generated_action_tokens + + +class PI0FastPolicy(PreTrainedPolicy): + """PI0Fast Policy for LeRobot.""" + + config_class = PI0FastConfig + name = "pi0_fast" + + def __init__( + self, + config: PI0FastConfig, + **kwargs, + ): + """ + Args: + config: Policy configuration class instance. + """ + super().__init__(config) + config.validate_features() + self.config = config + + # Load tokenizers first + try: + from transformers import AutoProcessor, AutoTokenizer + + # Load FAST tokenizer + self.action_tokenizer = AutoProcessor.from_pretrained( + config.action_tokenizer_name, trust_remote_code=True + ) + + # Load PaliGemma tokenizer for token conversion + self._paligemma_tokenizer = AutoTokenizer.from_pretrained( + config.text_tokenizer_name, trust_remote_code=True, add_eos_token=True, add_bos_token=False + ) + + logging.info("Loaded FAST tokenizer for action detokenization") + except Exception as e: + logging.error(f"Failed to load FAST tokenizer for action detokenization: {e}") + logging.error("Tokenizer loading is required for proper policy initialization; aborting.") + raise RuntimeError("Failed to load required tokenizers for PI0FastPolicy initialization") from e + + # Initialize the core PI0Fast model + self.init_rtc_processor() + self.model = PI0FastPytorch( + config, rtc_processor=self.rtc_processor, paligemma_tokenizer=self._paligemma_tokenizer + ) + + # Enable gradient checkpointing if requested + if config.gradient_checkpointing: + self.model.gradient_checkpointing_enable() + + self.model.to(config.device) + + self.reset() + + @classmethod + def from_pretrained( + cls: builtins.type[T], + pretrained_name_or_path: str | Path, + *, + config: PreTrainedConfig | None = None, + force_download: bool = False, + resume_download: bool | None = None, + proxies: dict | None = None, + token: str | bool | None = None, + cache_dir: str | Path | None = None, + local_files_only: bool = False, + revision: str | None = None, + strict: bool = True, + **kwargs, + ) -> T: + """Override the from_pretrained method to handle key remapping and display important disclaimer.""" + print( + "The PI0Fast model is a direct port of the OpenPI implementation. \n" + "This implementation follows the original OpenPI structure for compatibility. 
\n" + "Original implementation: https://github.com/Physical-Intelligence/openpi" + ) + if pretrained_name_or_path is None: + raise ValueError("pretrained_name_or_path is required") + + # Use provided config if available, otherwise create default config + if config is None: + config = PreTrainedConfig.from_pretrained( + pretrained_name_or_path=pretrained_name_or_path, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + token=token, + cache_dir=cache_dir, + local_files_only=local_files_only, + revision=revision, + **kwargs, + ) + + # Initialize model without loading weights + # Check if dataset_stats were provided in kwargs + model = cls(config, **kwargs) + + # Now manually load and remap the state dict + try: + # Try to load the pytorch_model.bin or model.safetensors file + print(f"Loading model from: {pretrained_name_or_path}") + try: + from transformers.utils import cached_file + + # Try safetensors first + resolved_file = cached_file( + pretrained_name_or_path, + "model.safetensors", + cache_dir=kwargs.get("cache_dir"), + force_download=kwargs.get("force_download", False), + resume_download=kwargs.get("resume_download"), + proxies=kwargs.get("proxies"), + use_auth_token=kwargs.get("use_auth_token"), + revision=kwargs.get("revision"), + local_files_only=kwargs.get("local_files_only", False), + ) + from safetensors.torch import load_file + + original_state_dict = load_file(resolved_file) + print("✓ Loaded state dict from model.safetensors") + except Exception as e: + print(f"Could not load state dict from remote files: {e}") + print("Returning model without loading pretrained weights") + return model + + # First, fix any key differences # see openpi `model.py, _fix_pytorch_state_dict_keys` + fixed_state_dict = model._fix_pytorch_state_dict_keys(original_state_dict, model.config) + # Then add "model." prefix for all keys that don't already have it + remapped_state_dict = {} + remap_count = 0 + + for key, value in fixed_state_dict.items(): + if not key.startswith("model."): + new_key = f"model.{key}" + remapped_state_dict[new_key] = value + remap_count += 1 + if remap_count <= 10: # Only print first 10 to avoid spam + print(f"Remapped: {key} -> {new_key}") + else: + remapped_state_dict[key] = value + + if remap_count > 0: + print(f"Remapped {remap_count} state dict keys") + + # Load the remapped state dict into the model + missing_keys, unexpected_keys = model.load_state_dict(remapped_state_dict, strict=strict) + + if missing_keys: + print(f"Missing keys when loading state dict: {len(missing_keys)} keys") + if len(missing_keys) <= 5: + for key in missing_keys: + print(f" - {key}") + else: + for key in missing_keys[:5]: + print(f" - {key}") + print(f" ... and {len(missing_keys) - 5} more") + + if unexpected_keys: + print(f"Unexpected keys when loading state dict: {len(unexpected_keys)} keys") + if len(unexpected_keys) <= 5: + for key in unexpected_keys: + print(f" - {key}") + else: + for key in unexpected_keys[:5]: + print(f" - {key}") + print(f" ... 
and {len(unexpected_keys) - 5} more") + + if not missing_keys and not unexpected_keys: + print("All keys loaded successfully!") + + except Exception as e: + print(f"Warning: Could not remap state dict keys: {e}") + + return model + + def _fix_pytorch_state_dict_keys( + self, state_dict, model_config + ): # see openpi `BaseModelConfig, _fix_pytorch_state_dict_keys` + """Fix state dict keys to match current model architecture.""" + + fixed_state_dict = {} + + for key, value in state_dict.items(): + new_key = key + + # Handle vision tower embedding layer potential differences + if "patch_embedding" in key: + # Some checkpoints might have this, but current model expects different structure + logging.warning(f"Vision embedding key might need handling: {key}") + + if ( + key == "model.paligemma_with_expert.paligemma.lm_head.weight" + or key == "paligemma_with_expert.paligemma.lm_head.weight" + ): + fixed_state_dict[ + "model.paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight" + ] = value.clone() + + fixed_state_dict[new_key] = value + + return fixed_state_dict + + def get_optim_params(self) -> dict: + return self.parameters() + + def reset(self): + """Reset internal state - called when environment resets.""" + self._action_queue = deque(maxlen=self.config.n_action_steps) + self._queues = { + ACTION: deque(maxlen=self.config.n_action_steps), + } + + def init_rtc_processor(self): + """Initialize RTC processor if RTC is enabled in config.""" + self.rtc_processor = None + + # Create processor if config provided + # If RTC is not enabled - we can still track the denoising data + if self.config.rtc_config is not None: + self.rtc_processor = RTCProcessor(self.config.rtc_config) + + model_value = getattr(self, "model", None) + if model_value is not None: + model_value.rtc_processor = self.rtc_processor + + def _rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + + def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: + """Preprocess images for the model. + + Images from LeRobot are typically in [B, C, H, W] format and normalized to [0, 1]. + PaliGemma expects images in [B, C, H, W] format and normalized to [-1, 1]. + """ + images = [] + img_masks = [] + + # Get device from model parameters + device = next(self.parameters()).device + + present_img_keys = [key for key in self.config.image_features if key in batch] + missing_img_keys = [key for key in self.config.image_features if key not in batch] + + if len(present_img_keys) == 0: + raise ValueError( + f"All image features are missing from the batch. At least one expected. 
" + f"(batch: {batch.keys()}) (image_features: {self.config.image_features})" + ) + + # Preprocess image features present in the batch + for key in present_img_keys: + img = batch[key] + + # Ensure tensor is on the same device as the model + if img.device != device: + img = img.to(device) + + # Ensure float32 dtype for consistency + if img.dtype != torch.float32: + img = img.to(torch.float32) + + # from openpi preprocess_observation_pytorch: Handle both [B, C, H, W] and [B, H, W, C] formats + is_channels_first = img.shape[1] == 3 # Check if channels are in dimension 1 + + if is_channels_first: + # Convert [B, C, H, W] to [B, H, W, C] for processing + img = img.permute(0, 2, 3, 1) + + # from openpi preprocess_observation_pytorch: Resize with padding if needed + if img.shape[1:3] != self.config.image_resolution: + img = resize_with_pad_torch(img, *self.config.image_resolution) + + # Normalize from [0,1] to [-1,1] as expected by siglip + img = img * 2.0 - 1.0 + + # from openpi preprocess_observation_pytorch: Convert back to [B, C, H, W] format if it was originally channels-first + if is_channels_first: + img = img.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W] + + images.append(img) + # Create mask (all ones for real images) + bsize = img.shape[0] + mask = torch.ones(bsize, dtype=torch.bool, device=device) + img_masks.append(mask) + + # Create image features not present in the batch as fully 0 padded images + for _num_empty_cameras in range(len(missing_img_keys)): + img = torch.ones_like(img) * -1 # Padded with -1 for SigLIP + mask = torch.zeros_like(mask) # Mask is zero for empty cameras + images.append(img) + img_masks.append(mask) + + return images, img_masks + + def prepare_action(self, batch): + """Pad action""" + actions = pad_vector(batch[ACTION], self.config.max_action_dim) + return actions + + def _paligemma_tokens_to_act_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts PaliGemma tokens back to action tokens (inverse of _act_tokens_to_paligemma_tokens). + + Args: + tokens: PaliGemma token IDs + + Returns: + Action token IDs + """ + return self._paligemma_tokenizer.vocab_size - 1 - self.config.fast_skip_tokens - tokens + + def decode_actions_with_fast( + self, token_ids: list[int], time_horizon: int, action_dim: int, relaxed_decoding: bool = True + ) -> np.ndarray: + """ + Decodes action token IDs back to continuous action values using the FAST tokenizer. + + Args: + token_ids: List of token IDs to decode. + time_horizon: The number of timesteps for actions. + action_dim: The dimensionality of each action. + relaxed_decoding: Whether to use relaxed decoding (allows partial sequences). + + Returns: + A numpy array representing the decoded actions. 
+ """ + decoded_actions = [] + + for token in token_ids: + try: + decoded_tokens = self.action_tokenizer.bpe_tokenizer.decode(token) + decoded_dct_coeff = np.array(list(map(ord, decoded_tokens))) + self.action_tokenizer.min_token + + if relaxed_decoding: + # expected sequence length + expected_seq_len = time_horizon * action_dim + diff = expected_seq_len - decoded_dct_coeff.shape[0] + + # apply truncation if too long + if diff < 0: + decoded_dct_coeff = decoded_dct_coeff[:expected_seq_len] # truncate on the right + + # apply padding if too short + elif diff > 0: + decoded_dct_coeff = np.pad( + decoded_dct_coeff, (0, diff), mode="constant", constant_values=0 + ) + + decoded_dct_coeff = decoded_dct_coeff.reshape(-1, action_dim) + assert decoded_dct_coeff.shape == ( + time_horizon, + action_dim, + ), ( + f"Decoded DCT coefficients have shape {decoded_dct_coeff.shape}, expected ({time_horizon}, {action_dim})" + ) + + except Exception as e: + logging.warning(f"Error decoding tokens: {e}") + logging.warning(f"Tokens: {token}") + decoded_dct_coeff = np.zeros((time_horizon, action_dim)) + + decoded_actions.append( + idct(decoded_dct_coeff / self.action_tokenizer.scale, axis=0, norm="ortho") + ) + + return np.stack(decoded_actions) + + def detokenize_actions(self, tokens: torch.Tensor, action_horizon: int, action_dim: int) -> torch.Tensor: + """ + Detokenizes action tokens back to continuous actions. + + This method converts predicted action tokens from the model back to continuous action values + using the FAST tokenizer. It handles the conversion from PaliGemma token space to action token + space, then decodes the action tokens to continuous values using DCT decoding. + + Args: + tokens: The input tensor of tokenized outputs. Shape: (B, seq_len) or (seq_len,) + action_horizon: The number of timesteps for actions. + action_dim: The dimensionality of each action. + + Returns: + The continuous action tensor. Shape: (B, action_horizon, action_dim) or (action_horizon, action_dim) + """ + if self.action_tokenizer is None or self._paligemma_tokenizer is None: + raise ValueError( + "Action tokenizer not initialized. Make sure fast_only=True in config and tokenizers loaded successfully." 
+ ) + + # Handle single sample (add batch dimension) + single_sample = tokens.dim() == 1 + if single_sample: + tokens = tokens.unsqueeze(0) + + # Convert token IDs to token strings + decoded_tokens = [self._paligemma_tokenizer.convert_ids_to_tokens(seq.tolist()) for seq in tokens] + # Get the token sequence for "Action: " to remove it + action_prefix_ids = self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False) + action_prefix_tokens = self._paligemma_tokenizer.convert_ids_to_tokens(action_prefix_ids) + action_prefix_len = len(action_prefix_tokens) + + # Clean tokens by removing everything after the first "|" (end-of-action marker) + # and removing all occurrences of "Action: " token sequence + # assert that beginning contain "Action: " + if self.config.validate_action_token_prefix: + for token_seq in decoded_tokens: + assert len(token_seq) >= 2 and token_seq[0] == "Action" and token_seq[1] == ":", ( + f"Token sequence does not start with ['Action', ':']: {token_seq}" + ) + + cleaned_tokens = [] + for token_seq in decoded_tokens: + # Remove everything after "|" + if "|" in token_seq: + token_seq = token_seq[: token_seq.index("|")] + + # Remove all occurrences of "Action: " token sequence + i = 0 + while i <= len(token_seq) - action_prefix_len: + if token_seq[i : i + action_prefix_len] == action_prefix_tokens: + # Found a match, remove it + token_seq = token_seq[:i] + token_seq[i + action_prefix_len :] + else: + i += 1 + + cleaned_tokens.append(token_seq) + + # Convert token strings back to IDs + raw_action_tokens = [ + torch.tensor( + self._paligemma_tokenizer.convert_tokens_to_ids(token_seq), + dtype=torch.long, + device=tokens.device, + ) + for token_seq in cleaned_tokens + ] + + # Convert PaliGemma tokens to action tokens + action_tokens = [ + self._paligemma_tokens_to_act_tokens(raw_action_token) for raw_action_token in raw_action_tokens + ] + + # Decode action tokens to continuous actions + actions = self.decode_actions_with_fast( + action_tokens, time_horizon=action_horizon, action_dim=action_dim + ) + + # Convert to tensor and return + actions_tensor = torch.tensor(actions, dtype=torch.float32, device=tokens.device) + + # Remove batch dimension if input was single sample + if single_sample: + actions_tensor = actions_tensor.squeeze(0) + + return actions_tensor + + @torch.no_grad() + def select_action(self, batch: dict[str, Tensor]) -> Tensor: + """Select a single action given environment observations.""" + assert not self._rtc_enabled(), ( + "RTC is not supported for select_action, use it with predict_action_chunk" + ) + + self.eval() + + # Action queue logic for n_action_steps > 1 + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch)[:, : self.config.n_action_steps] + # Transpose to get shape (n_action_steps, batch_size, action_dim) + self._action_queue.extend(actions.transpose(0, 1)) + + return self._action_queue.popleft() + + @torch.no_grad() + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: + """Predict a chunk of actions given environment observations.""" + self.eval() + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + + # FAST-only mode: use autoregressive decoding + tokens = batch[f"{OBS_LANGUAGE_TOKENS}"] + masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] + + # Get decoding parameters + temperature = self.config.temperature + max_decoding_steps = self.config.max_decoding_steps + + # Sample action tokens autoregressively + if self.config.use_kv_cache: + 
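+            # KV-cache path: key/value states for the image and language prefix are
+            # computed once and reused at each autoregressive step, so every new
+            # action token attends over the cached prefix instead of recomputing it.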
action_tokens = self.model.sample_actions_fast_kv_cache( + images, + img_masks, + tokens, + masks, + max_decoding_steps=max_decoding_steps, + temperature=temperature, + ) + else: + action_tokens = self.model.sample_actions_fast( + images, + img_masks, + tokens, + masks, + max_decoding_steps=max_decoding_steps, + temperature=temperature, + ) + + # Detokenize action tokens to continuous actions + action_horizon = self.config.n_action_steps + action_dim = self.config.output_features[ACTION].shape[0] + + continuous_actions = self.detokenize_actions( + action_tokens, action_horizon=action_horizon, action_dim=action_dim + ) + + return continuous_actions + + def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]: + """Run the batch through the model and compute the loss for training.""" + + # Prepare inputs + images, img_masks = self._preprocess_images(batch) + + # Get FAST action tokens from batch + fast_action_tokens = batch.get(ACTION_TOKENS) # (B, max_action_tokens) + fast_action_masks = batch.get(ACTION_TOKEN_MASK) # (B, max_action_tokens) + + # Use full language tokens (no separation into high_level_task and subtask) + tokens = batch.get(OBS_LANGUAGE_TOKENS) + masks = batch.get(OBS_LANGUAGE_ATTENTION_MASK) + + if fast_action_tokens is None or fast_action_masks is None: + raise ValueError( + f"PI0Fast requires {ACTION_TOKENS} and {ACTION_TOKEN_MASK} to be present in the batch" + ) + + loss_dict = self.model.forward( + images, + img_masks, + tokens, + masks, + fast_action_tokens, + fast_action_masks, + ) + + loss = loss_dict["loss"] + detailed_loss_dict = { + "loss": loss.item(), + "ce_loss": loss_dict["ce_loss"].item(), + } + return loss, detailed_loss_dict diff --git a/src/lerobot/policies/pi0_fast/processor_pi0_fast.py b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py new file mode 100644 index 000000000..0d9dac673 --- /dev/null +++ b/src/lerobot/policies/pi0_fast/processor_pi0_fast.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +# Copyright 2025 Physical Intelligence and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
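The processor module below embeds the (already normalized) robot state directly in the text prompt. As a minimal illustration of the discretization scheme used by `Pi0FastPrepareStateAndLanguageTokenizerProcessorStep` (the state values here are made up and assumed to be normalized to [-1, 1]):

```python
import numpy as np

# Normalized state is bucketed into 256 integer bins and rendered into the prompt.
state = np.array([-1.0, -0.5, 0.0, 0.99])        # assumed already in [-1, 1]
bins = np.linspace(-1, 1, 256 + 1)[:-1]          # 256 left bin edges
discretized = np.digitize(state, bins=bins) - 1  # -> array([  0,  64, 128, 254])
task = "pick up the red block"
prompt = f"Task: {task}, State: {' '.join(map(str, discretized))};\n"
print(prompt)  # Task: pick up the red block, State: 0 64 128 254;
```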
+
+from copy import deepcopy
+from dataclasses import dataclass
+from typing import Any
+
+import numpy as np
+import torch
+
+from lerobot.configs.types import PipelineFeatureType, PolicyFeature
+from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
+from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
+from lerobot.processor import (
+    ActionTokenizerProcessorStep,
+    AddBatchDimensionProcessorStep,
+    DeviceProcessorStep,
+    NormalizerProcessorStep,
+    PolicyAction,
+    PolicyProcessorPipeline,
+    ProcessorStep,
+    ProcessorStepRegistry,
+    RenameObservationsProcessorStep,
+    TokenizerProcessorStep,
+    UnnormalizerProcessorStep,
+)
+from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
+from lerobot.processor.core import EnvTransition, TransitionKey
+from lerobot.utils.constants import (
+    OBS_STATE,
+    POLICY_POSTPROCESSOR_DEFAULT_NAME,
+    POLICY_PREPROCESSOR_DEFAULT_NAME,
+)
+
+
+@ProcessorStepRegistry.register(name="pi0_fast_prepare_state_tokenizer_processor_step")
+@dataclass
+class Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(ProcessorStep):
+    """
+    Processor step to prepare the state and tokenize the language input.
+    """
+
+    max_state_dim: int = 32
+    task_key: str = "task"
+
+    def __call__(self, transition: EnvTransition) -> EnvTransition:
+        transition = transition.copy()
+
+        state = transition.get(TransitionKey.OBSERVATION, {}).get(OBS_STATE)
+        if state is None:
+            raise ValueError("State is required for PI0Fast")
+        tasks = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}).get(self.task_key)
+        if tasks is None:
+            raise ValueError("No task found in complementary data")
+
+        # TODO: check whether this copy is actually necessary
+        state = deepcopy(state)
+
+        # Prepare state (pad to max_state_dim)
+        state = pad_vector(state, self.max_state_dim)
+
+        # State should already be normalized to [-1, 1] by the NormalizerProcessorStep that runs before this step.
+        # Discretize into 256 bins (see openpi `PaligemmaTokenizer.tokenize()`)
+        state_np = state.cpu().numpy()
+        discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
+
+        full_prompts = []
+        for i, task in enumerate(tasks):
+            cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
+            state_str = " ".join(map(str, discretized_states[i]))
+            full_prompt = f"Task: {cleaned_text}, State: {state_str};\n"
+            full_prompts.append(full_prompt)
+
+        transition[TransitionKey.COMPLEMENTARY_DATA][self.task_key] = full_prompts
+        return transition
+
+    def transform_features(
+        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
+    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
+        """
+        This step does not alter the feature definitions.
+        """
+        return features
+
+
+def make_pi0_fast_pre_post_processors(
+    config: PI0FastConfig,
+    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
+) -> tuple[
+    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
+    PolicyProcessorPipeline[PolicyAction, PolicyAction],
+]:
+    """
+    Constructs pre-processor and post-processor pipelines for the PI0Fast policy.
+
+    The pre-processing pipeline prepares input data for the model by:
+    1. Renaming features to match pretrained configurations.
+    2. Normalizing input and output features based on dataset statistics.
+    3. Adding a batch dimension.
+    4.
Preparing the state (padding to `max_state_dim`, 256-bin discretization) and embedding it in the language prompt.
+    5. Tokenizing the text prompt using the PaliGemma tokenizer.
+    6. Tokenizing the action chunk with the FAST action tokenizer (when an action is present).
+    7. Moving all data to the specified device.
+
+    The post-processing pipeline handles the model's output by:
+    1. Unnormalizing the output features to their original scale.
+    2. Moving data to the CPU.
+
+    Args:
+        config: The configuration object for the PI0Fast policy.
+        dataset_stats: A dictionary of statistics for normalization.
+
+    Returns:
+        A tuple containing the configured pre-processor and post-processor pipelines.
+    """
+    input_steps: list[ProcessorStep] = [
+        RenameObservationsProcessorStep(rename_map={}),  # To mimic the same processor as the pretrained one
+        AddBatchDimensionProcessorStep(),
+        # NOTE: NormalizerProcessorStep MUST come before Pi0FastPrepareStateAndLanguageTokenizerProcessorStep
+        # because the state-preparation step expects normalized state in [-1, 1] range for discretization
+        NormalizerProcessorStep(
+            features={**config.input_features, **config.output_features},
+            norm_map=config.normalization_mapping,
+            stats=dataset_stats,
+        ),
+        Pi0FastPrepareStateAndLanguageTokenizerProcessorStep(max_state_dim=config.max_state_dim),
+        TokenizerProcessorStep(
+            tokenizer_name=config.text_tokenizer_name,
+            max_length=config.tokenizer_max_length,
+            padding_side="right",
+            padding="max_length",
+        ),
+        ActionTokenizerProcessorStep(
+            action_tokenizer_name=config.action_tokenizer_name,
+            max_action_tokens=config.max_action_tokens,
+            fast_skip_tokens=config.fast_skip_tokens,
+            paligemma_tokenizer_name=config.text_tokenizer_name,
+        ),
+        DeviceProcessorStep(device=config.device),
+    ]
+
+    output_steps: list[ProcessorStep] = [
+        UnnormalizerProcessorStep(
+            features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
+        ),
+        DeviceProcessorStep(device="cpu"),
+    ]
+
+    return (
+        PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
+            steps=input_steps,
+            name=POLICY_PREPROCESSOR_DEFAULT_NAME,
+        ),
+        PolicyProcessorPipeline[PolicyAction, PolicyAction](
+            steps=output_steps,
+            name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
+            to_transition=policy_action_to_transition,
+            to_output=transition_to_policy_action,
+        ),
+    )
diff --git a/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py b/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py
new file mode 100644
index 000000000..6a3a1fe69
--- /dev/null
+++ b/src/lerobot/policies/pi0_fast/train_fast_tokenizer.py
@@ -0,0 +1,539 @@
+"""Train FAST tokenizer for action encoding.
+
+This script:
+1. Loads action chunks from LeRobotDataset (with sampling)
+2. Applies delta transforms and per-dimension normalization
+3. Trains a FAST tokenizer on the specified action dimensions
+4. Saves the tokenizer and metadata to the output directory
+5. Reports compression statistics
+"""
+
+import json
+from pathlib import Path
+
+import numpy as np
+import torch
+import tyro
+from huggingface_hub import HfApi
+from transformers import AutoProcessor
+
+from lerobot.configs.types import NormalizationMode
+from lerobot.datasets.lerobot_dataset import LeRobotDataset
+
+
+def apply_delta_transform(state: np.ndarray, actions: np.ndarray, delta_dims: list[int] | None) -> np.ndarray:
+    """Apply delta transform to specified dimensions.
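+
+    For example, with state = [0.2, 0.0], actions = [0.5, 1.0] and delta_dims = [0],
+    the result is [0.3, 1.0]: dimension 0 becomes relative to the current state
+    while all other dimensions stay absolute.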
+ + Args: + state: Current state [D] + actions: Future actions [D] + delta_dims: List of dimension indices to apply delta transform to + + Returns: + Transformed actions [D] + """ + if delta_dims is None or len(delta_dims) == 0: + return actions + + delta_actions = actions.copy() + for dim in delta_dims: + delta_actions[dim] = actions[dim] - state[dim] + + return delta_actions + + +def apply_normalization( + data: np.ndarray, + stats: dict[str, np.ndarray], + mode: NormalizationMode, + eps: float = 1e-8, +) -> np.ndarray: + """Apply normalization to data based on the specified mode. + + Args: + data: Data to normalize [N, H, D] or [D] + stats: Dictionary of statistics (mean, std, min, max, q01, q99, q10, q90) + mode: Normalization mode to apply + eps: Small epsilon for numerical stability + + Returns: + Normalized data with the same shape as input + """ + if mode == NormalizationMode.IDENTITY: + return data + + if mode == NormalizationMode.MEAN_STD: + mean = stats.get("mean") + std = stats.get("std") + if mean is None or std is None: + raise ValueError("MEAN_STD mode requires 'mean' and 'std' in stats") + return (data - mean) / np.maximum(std, eps) + + if mode == NormalizationMode.MIN_MAX: + min_val = stats.get("min") + max_val = stats.get("max") + if min_val is None or max_val is None: + raise ValueError("MIN_MAX mode requires 'min' and 'max' in stats") + denom = np.maximum(max_val - min_val, eps) + return 2.0 * (data - min_val) / denom - 1.0 + + if mode == NormalizationMode.QUANTILES: + q01 = stats.get("q01") + q99 = stats.get("q99") + if q01 is None or q99 is None: + raise ValueError("QUANTILES mode requires 'q01' and 'q99' in stats") + denom = np.maximum(q99 - q01, eps) + # Clip to quantile range then normalize to [-1, 1] + clipped = np.clip(data, q01, q99) + return 2.0 * (clipped - q01) / denom - 1.0 + + if mode == NormalizationMode.QUANTILE10: + q10 = stats.get("q10") + q90 = stats.get("q90") + if q10 is None or q90 is None: + raise ValueError("QUANTILE10 mode requires 'q10' and 'q90' in stats") + denom = np.maximum(q90 - q10, eps) + # Clip to quantile range then normalize to [-1, 1] + clipped = np.clip(data, q10, q90) + return 2.0 * (clipped - q10) / denom - 1.0 + + raise ValueError(f"Unsupported normalization mode: {mode}") + + +def process_episode(args): + """Process single episode and return action chunks.""" + dataset, ep_idx, action_horizon, delta_dims, sample_fraction, state_key, use_delta_transform = args + + try: + # get episode info + ep_info = dataset.meta.episodes[ep_idx] + from_idx = ep_info["dataset_from_index"] + to_idx = ep_info["dataset_to_index"] + ep_length = to_idx - from_idx + + if ep_length < action_horizon: + return None + + # load all frames in episode + # if dataset has episode filtering, we need to use the mapping + states = [] + actions = [] + + for abs_idx in range(from_idx, to_idx): + # map absolute index to relative index if needed + if dataset._absolute_to_relative_idx is not None: + if abs_idx not in dataset._absolute_to_relative_idx: + # this episode's frames aren't in the filtered dataset + return None + rel_idx = dataset._absolute_to_relative_idx[abs_idx] + else: + rel_idx = abs_idx + + frame = dataset.hf_dataset[rel_idx] + + # get state (could be from observation.state or other state key) + if state_key in frame: + state = ( + frame[state_key].numpy() + if torch.is_tensor(frame[state_key]) + else np.array(frame[state_key]) + ) + else: + # if no state key, use zeros (no delta transform) + state = np.zeros_like( + frame["action"].numpy() if 
torch.is_tensor(frame["action"]) else np.array(frame["action"]) + ) + + action = ( + frame["action"].numpy() if torch.is_tensor(frame["action"]) else np.array(frame["action"]) + ) + + states.append(state) + actions.append(action) + + states = np.array(states) + actions = np.array(actions) + + # create action chunks (sliding window) + # all actions in a chunk are relative to the FIRST state in that chunk + action_chunks = [] + + for i in range(len(states) - action_horizon + 1): + current_state = states[i] # First state in chunk + future_absolute_actions = actions[i : i + action_horizon] + + if use_delta_transform: + # relative actions + delta_chunk = np.zeros_like(future_absolute_actions) + for t in range(action_horizon): + delta_chunk[t] = apply_delta_transform( + current_state, + future_absolute_actions[t], + delta_dims, + ) + action_chunks.append(delta_chunk) + else: + # absolute actions (no delta) + action_chunks.append(future_absolute_actions) + + if len(action_chunks) == 0: + return None + + action_chunks = np.array(action_chunks) + + # sample chunks + if sample_fraction < 1.0: + n_chunks = len(action_chunks) + n_samples = max(1, int(n_chunks * sample_fraction)) + episode_seed = hash(ep_idx) % (2**31) + rng = np.random.RandomState(episode_seed) + indices = rng.choice(n_chunks, size=n_samples, replace=False) + action_chunks = action_chunks[indices] + + return action_chunks + + except Exception as e: + print(f"Error processing episode {ep_idx}: {e}") + import traceback + + traceback.print_exc() + return None + + +def train_fast_tokenizer( + action_chunks: np.ndarray, + vocab_size: int = 1024, + scale: float = 10.0, +) -> AutoProcessor: + """ + Train FAST tokenizer (BPE on DCT coefficients) on action chunks. + + Uses the .fit() method to train a new tokenizer on the provided data. 
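+
+    For example, calling train_fast_tokenizer(chunks, vocab_size=1024, scale=10.0)
+    with chunks of shape (N, 10, 7) returns a fitted processor that maps each
+    (10, 7) chunk to a short sequence of discrete token IDs.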
+ + Args: + action_chunks: Array of action chunks [N, H, D] where N=num_chunks, H=horizon, D=action_dim + vocab_size: BPE vocabulary size + scale: DCT scaling factor for quantization + + Returns: + Trained FAST tokenizer + """ + print(f"Training FAST tokenizer on {len(action_chunks)} action chunks...") + print(f"Action chunk shape: {action_chunks.shape}") + print(f"Vocab size: {vocab_size}") + print(f"DCT scale: {scale}") + + # download the tokenizer source code (not pretrained weights) + # we'll train a new tokenizer on our own data + base_tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True) + + # convert action_chunks array to list of arrays (expected by .fit()) + action_data_list = [action_chunks[i] for i in range(len(action_chunks))] + + # train the new tokenizer on our action data using .fit() + # this trains the BPE tokenizer on DCT coefficients + print("Training new tokenizer (this may take a few minutes)...") + tokenizer = base_tokenizer.fit( + action_data_list, + scale=scale, + vocab_size=vocab_size, + time_horizon=action_chunks.shape[1], # action_horizon + action_dim=action_chunks.shape[2], # encoded dimensions + ) + print("✓ Tokenizer training complete!") + + # validate it works + sample_chunk = action_chunks[0] + encoded = tokenizer(sample_chunk[None])[0] + if isinstance(encoded, list): + encoded = np.array(encoded) + print(f"Sample encoding: {len(encoded)} tokens for chunk shape {sample_chunk.shape}") + + return tokenizer + + +def compute_compression_stats(tokenizer, action_chunks: np.ndarray): + """Compute compression statistics.""" + print("\nComputing compression statistics...") + + # sample for stats (use max 1000 chunks for speed) + sample_size = min(1000, len(action_chunks)) + sample_indices = np.random.RandomState(42).choice(len(action_chunks), size=sample_size, replace=False) + sample_chunks = action_chunks[sample_indices] + + token_lengths = [] + for chunk in sample_chunks: + encoded = tokenizer(chunk[None])[0] + if isinstance(encoded, list): + token_lengths.append(len(encoded)) + else: + token_lengths.append(encoded.shape[0] if hasattr(encoded, "shape") else len(encoded)) + + token_lengths = np.array(token_lengths) + + # compression ratio: (H * D) / avg_tokens + input_size = action_chunks.shape[1] * action_chunks.shape[2] + avg_tokens = np.mean(token_lengths) + compression_ratio = input_size / avg_tokens + + stats = { + "compression_ratio": float(compression_ratio), + "mean_token_length": float(np.mean(token_lengths)), + "p99_token_length": float(np.percentile(token_lengths, 99)), + "min_token_length": float(np.min(token_lengths)), + "max_token_length": float(np.max(token_lengths)), + } + + print("Compression Statistics:") + print(f" Average compression ratio: {stats['compression_ratio']:.2f}x") + print(f" Mean token length: {stats['mean_token_length']:.1f}") + print(f" P99 token length: {stats['p99_token_length']:.0f}") + print(f" Min token length: {stats['min_token_length']:.0f}") + print(f" Max token length: {stats['max_token_length']:.0f}") + + return stats + + +def main( + repo_id: str, + root: str | None = None, + action_horizon: int = 10, + max_episodes: int | None = None, + sample_fraction: float = 0.1, + encoded_dims: str = "0:6,7:23", + delta_dims: str | None = None, + use_delta_transform: bool = False, + state_key: str = "observation.state", + normalization_mode: str = "QUANTILES", + vocab_size: int = 1024, + scale: float = 10.0, + output_dir: str | None = None, + push_to_hub: bool = False, + hub_repo_id: str | 
None = None, + hub_private: bool = False, +): + """ + Train FAST tokenizer for action encoding. + + Args: + repo_id: LeRobot dataset repository ID + root: Root directory for dataset (default: ~/.cache/huggingface/lerobot) + action_horizon: Number of future actions in each chunk + max_episodes: Max episodes to use (None = all episodes in dataset) + sample_fraction: Fraction of chunks to sample per episode + encoded_dims: Comma-separated dimension ranges to encode (e.g., "0:6,7:23") + delta_dims: Comma-separated dimension indices for delta transform (e.g., "0,1,2,3,4,5") + use_delta_transform: Whether to apply delta transform (relative actions vs absolute actions) + state_key: Dataset key for state observations (default: "observation.state") + normalization_mode: Normalization mode (MEAN_STD, MIN_MAX, QUANTILES, QUANTILE10, IDENTITY) + vocab_size: FAST vocabulary size (BPE vocab size) + scale: DCT scaling factor (default: 10.0) + output_dir: Directory to save tokenizer (default: ./fast_tokenizer_{repo_id}) + push_to_hub: Whether to push the tokenizer to Hugging Face Hub + hub_repo_id: Hub repository ID (e.g., "username/tokenizer-name"). If None, uses output_dir name + hub_private: Whether to create a private repository on the Hub + """ + # load dataset + print(f"Loading dataset: {repo_id}") + dataset = LeRobotDataset(repo_id=repo_id, root=root) + print(f"Dataset loaded: {dataset.num_episodes} episodes, {dataset.num_frames} frames") + + # parse normalization mode + try: + norm_mode = NormalizationMode(normalization_mode) + except ValueError as err: + raise ValueError( + f"Invalid normalization_mode: {normalization_mode}. " + f"Must be one of: {', '.join([m.value for m in NormalizationMode])}" + ) from err + print(f"Normalization mode: {norm_mode.value}") + + # parse encoded dimensions + encoded_dim_ranges = [] + for range_str in encoded_dims.split(","): + start, end = map(int, range_str.strip().split(":")) + encoded_dim_ranges.append((start, end)) + + total_encoded_dims = sum(end - start for start, end in encoded_dim_ranges) + print(f"Encoding {total_encoded_dims} dimensions: {encoded_dims}") + + # parse delta dimensions + delta_dim_list = None + if delta_dims is not None and delta_dims.strip(): + delta_dim_list = [int(d.strip()) for d in delta_dims.split(",")] + print(f"Delta dimensions: {delta_dim_list}") + else: + print("No delta dimensions specified") + + print(f"Use delta transform: {use_delta_transform}") + if use_delta_transform and (delta_dim_list is None or len(delta_dim_list) == 0): + print("Warning: use_delta_transform=True but no delta_dims specified. 
No delta will be applied.") + + print(f"Action horizon: {action_horizon}") + print(f"State key: {state_key}") + + # determine episodes to process + num_episodes = dataset.num_episodes + if max_episodes is not None: + num_episodes = min(max_episodes, num_episodes) + + print(f"Processing {num_episodes} episodes...") + + # process episodes sequentially (to avoid pickling issues with dataset) + all_chunks = [] + for ep_idx in range(num_episodes): + if ep_idx % 10 == 0: + print(f" Processing episode {ep_idx}/{num_episodes}...") + + chunks = process_episode( + (dataset, ep_idx, action_horizon, delta_dim_list, sample_fraction, state_key, use_delta_transform) + ) + if chunks is not None: + all_chunks.append(chunks) + + # concatenate all chunks + all_chunks = np.concatenate(all_chunks, axis=0) + print(f"Collected {len(all_chunks)} action chunks") + + # extract only encoded dimensions FIRST (before normalization) + encoded_chunks = [] + for start, end in encoded_dim_ranges: + encoded_chunks.append(all_chunks[:, :, start:end]) + encoded_chunks = np.concatenate(encoded_chunks, axis=-1) # [N, H, D_encoded] + print(f"Extracted {encoded_chunks.shape[-1]} encoded dimensions") + + # apply normalization to encoded dimensions + print("\nBefore normalization - overall stats:") + print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}") + print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}") + + # get normalization stats from dataset + norm_stats = dataset.meta.stats + if norm_stats is not None and "action" in norm_stats: + action_stats = norm_stats["action"] + + # build encoded dimension indices + encoded_dim_indices = [] + for start, end in encoded_dim_ranges: + encoded_dim_indices.extend(range(start, end)) + encoded_dim_indices = np.array(encoded_dim_indices) + + # extract stats for encoded dimensions only + encoded_stats = {} + for stat_name, stat_values in action_stats.items(): + if isinstance(stat_values, (list, np.ndarray)): + stat_array = np.array(stat_values) + if len(stat_array) > max(encoded_dim_indices): + encoded_stats[stat_name] = stat_array[encoded_dim_indices] + + if encoded_stats: + print(f"\nNormalization stats for encoded dimensions (mode: {norm_mode.value}):") + for stat_name, stat_values in encoded_stats.items(): + print( + f" {stat_name}: shape={stat_values.shape}, " + f"range=[{np.min(stat_values):.4f}, {np.max(stat_values):.4f}]" + ) + + # apply normalization based on mode + try: + encoded_chunks = apply_normalization(encoded_chunks, encoded_stats, norm_mode, eps=1e-8) + print(f"\nApplied {norm_mode.value} normalization") + except ValueError as e: + print(f"Warning: {e}. 
Using raw actions without normalization.") + + print("\nAfter normalization - overall stats:") + print(f" Min: {np.min(encoded_chunks):.4f}, Max: {np.max(encoded_chunks):.4f}") + print(f" Mean: {np.mean(encoded_chunks):.4f}, Std: {np.std(encoded_chunks):.4f}") + + print("\nPer-dimension stats (after normalization):") + for d in range(encoded_chunks.shape[-1]): + dim_data = encoded_chunks[:, :, d] + print( + f" Dim {d}: min={np.min(dim_data):7.4f}, max={np.max(dim_data):7.4f}, " + f"mean={np.mean(dim_data):7.4f}, std={np.std(dim_data):7.4f}" + ) + else: + print("Warning: Could not extract stats for encoded dimensions, using raw actions") + else: + print("Warning: No normalization stats found in dataset, using raw actions") + + print(f"Encoded chunks shape: {encoded_chunks.shape}") + + # train FAST tokenizer + tokenizer = train_fast_tokenizer( + encoded_chunks, + vocab_size=vocab_size, + scale=scale, + ) + + # compute compression statistics + compression_stats = compute_compression_stats(tokenizer, encoded_chunks) + + # save tokenizer + if output_dir is None: + output_dir = f"fast_tokenizer_{repo_id.replace('/', '_')}" + output_path = Path(output_dir) + output_path.mkdir(parents=True, exist_ok=True) + + tokenizer.save_pretrained(output_path) + + # save metadata + metadata = { + "repo_id": repo_id, + "vocab_size": vocab_size, + "scale": scale, + "encoded_dims": encoded_dims, + "encoded_dim_ranges": encoded_dim_ranges, + "total_encoded_dims": total_encoded_dims, + "delta_dims": delta_dims, + "delta_dim_list": delta_dim_list, + "use_delta_transform": use_delta_transform, + "state_key": state_key, + "normalization_mode": norm_mode.value, + "action_horizon": action_horizon, + "num_training_chunks": len(encoded_chunks), + "compression_stats": compression_stats, + } + + with open(output_path / "metadata.json", "w") as f: + json.dump(metadata, f, indent=2) + + print(f"\nSaved FAST tokenizer to {output_path}") + print(f"Metadata: {json.dumps(metadata, indent=2)}") + + # push to Hugging Face Hub if requested + if push_to_hub: + # determine the hub repository ID + if hub_repo_id is None: + hub_repo_id = output_path.name + print(f"\nNo hub_repo_id provided, using: {hub_repo_id}") + + print(f"\nPushing tokenizer to Hugging Face Hub: {hub_repo_id}") + print(f" Private: {hub_private}") + + try: + # use the tokenizer's push_to_hub method + tokenizer.push_to_hub( + repo_id=hub_repo_id, + private=hub_private, + commit_message=f"Upload FAST tokenizer trained on {repo_id}", + ) + + # also upload the metadata.json file separately + api = HfApi() + api.upload_file( + path_or_fileobj=str(output_path / "metadata.json"), + path_in_repo="metadata.json", + repo_id=hub_repo_id, + repo_type="model", + commit_message="Upload tokenizer metadata", + ) + + print(f"Successfully pushed tokenizer to: https://huggingface.co/{hub_repo_id}") + except Exception as e: + print(f"Error pushing to hub: {e}") + print(" Make sure you're logged in with `huggingface-cli login`") + + +if __name__ == "__main__": + tyro.cli(main) diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index be11ac1af..676ba29ee 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -75,7 +75,7 @@ from .policy_robot_bridge import ( RobotActionToPolicyActionProcessorStep, ) from .rename_processor import RenameObservationsProcessorStep -from .tokenizer_processor import TokenizerProcessorStep +from .tokenizer_processor import ActionTokenizerProcessorStep, TokenizerProcessorStep __all__ = [ 
"ActionProcessorStep", @@ -122,6 +122,7 @@ __all__ = [ "AddBatchDimensionProcessorStep", "RobotProcessorPipeline", "TokenizerProcessorStep", + "ActionTokenizerProcessorStep", "Torch2NumpyActionProcessorStep", "RobotActionToPolicyActionProcessorStep", "PolicyActionToRobotActionProcessorStep", diff --git a/src/lerobot/processor/tokenizer_processor.py b/src/lerobot/processor/tokenizer_processor.py index 2ef89c107..93e0395b9 100644 --- a/src/lerobot/processor/tokenizer_processor.py +++ b/src/lerobot/processor/tokenizer_processor.py @@ -23,22 +23,29 @@ token IDs and attention masks, which are then added to the observation dictionar from __future__ import annotations +import logging from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import torch from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature -from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS +from lerobot.utils.constants import ( + ACTION_TOKEN_MASK, + ACTION_TOKENS, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, +) from lerobot.utils.import_utils import _transformers_available from .core import EnvTransition, TransitionKey -from .pipeline import ObservationProcessorStep, ProcessorStepRegistry +from .pipeline import ActionProcessorStep, ObservationProcessorStep, ProcessorStepRegistry # Conditional import for type checking and lazy loading if TYPE_CHECKING or _transformers_available: - from transformers import AutoTokenizer + from transformers import AutoProcessor, AutoTokenizer else: + AutoProcessor = None AutoTokenizer = None @@ -268,3 +275,256 @@ class TokenizerProcessorStep(ObservationProcessorStep): ) return features + + +@dataclass +@ProcessorStepRegistry.register(name="action_tokenizer_processor") +class ActionTokenizerProcessorStep(ActionProcessorStep): + """ + Processor step to tokenize action data using a fast action tokenizer. + + This step takes action tensors from an `EnvTransition`, tokenizes them using + a Hugging Face `transformers` AutoProcessor (such as the Physical Intelligence "fast" tokenizer), + and returns the tokenized action. + + Requires the `transformers` library to be installed. + + Attributes: + tokenizer_name: The name of a pretrained processor from the Hugging Face Hub (e.g., "physical-intelligence/fast"). + tokenizer: A pre-initialized processor/tokenizer object. If provided, `tokenizer_name` is ignored. + trust_remote_code: Whether to trust remote code when loading the tokenizer (required for some tokenizers). + action_tokenizer: The internal tokenizer/processor instance, loaded during initialization. + paligemma_tokenizer_name: The name of a pretrained PaliGemma tokenizer from the Hugging Face Hub (e.g., "google/paligemma-3b-pt-224"). + """ + + action_tokenizer_name: str | None = None + action_tokenizer_input_object: Any | None = None + trust_remote_code: bool = True + max_action_tokens: int = 256 + fast_skip_tokens: int = 128 + paligemma_tokenizer_name: str = "google/paligemma-3b-pt-224" + # Internal tokenizer instance (not part of the config) + action_tokenizer: Any = field(default=None, init=False, repr=False) + _paligemma_tokenizer: Any = field(default=None, init=False, repr=False) + + def __post_init__(self): + """ + Initializes the action tokenizer after the dataclass is created. + + It checks for the availability of the `transformers` library and loads the tokenizer + either from a provided object or by name from the Hugging Face Hub. 
+ + Raises: + ImportError: If the `transformers` library is not installed. + ValueError: If neither `tokenizer` nor `tokenizer_name` is provided. + """ + if not _transformers_available: + raise ImportError( + "The 'transformers' library is not installed. " + "Please install it with `pip install 'lerobot[transformers-dep]'` to use ActionTokenizerProcessorStep." + ) + + if self.action_tokenizer_input_object is not None: + self.action_tokenizer = self.action_tokenizer_input_object + + elif self.action_tokenizer_name is not None: + if AutoProcessor is None: + raise ImportError("AutoProcessor is not available") + self.action_tokenizer = AutoProcessor.from_pretrained( + self.action_tokenizer_name, trust_remote_code=self.trust_remote_code + ) + else: + raise ValueError( + "Either 'action_tokenizer' or 'action_tokenizer_name' must be provided. " + "Pass a tokenizer object directly or a tokenizer name to auto-load." + ) + + self._paligemma_tokenizer = AutoTokenizer.from_pretrained( + self.paligemma_tokenizer_name, + trust_remote_code=self.trust_remote_code, + add_eos_token=True, + add_bos_token=False, + ) + + def __call__(self, transition: EnvTransition) -> EnvTransition: + """ + Applies action tokenization to the transition. + + This overrides the base class to handle both tokens and mask. + + Args: + transition: The input transition with action data. + + Returns: + The processed transition with tokenized actions and mask in complementary data. + """ + self._current_transition = transition.copy() + new_transition = self._current_transition + + action = new_transition.get(TransitionKey.ACTION) + if action is None: + # During inference, no action is available, skip tokenization + return new_transition + + # Tokenize and get both tokens and mask + tokens, mask = self._tokenize_action(action) + + # Store mask in complementary data + complementary_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) + if complementary_data is None: + complementary_data = {} + complementary_data[ACTION_TOKEN_MASK] = mask + complementary_data[ACTION_TOKENS] = tokens + new_transition[TransitionKey.COMPLEMENTARY_DATA] = complementary_data + return new_transition + + def _act_tokens_to_paligemma_tokens(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts action tokens to PaliGemma tokens. + """ + return self._paligemma_tokenizer.vocab_size - 1 - self.fast_skip_tokens - tokens + + def _tokenize_action(self, action: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Tokenizes the action tensor and creates a mask. + + Args: + action: The input action tensor to tokenize. 
Shape: (B, H, action_dim) or (H, action_dim,) + + Returns: + A tuple of (tokens, mask) where: + - tokens: Tensor of token IDs with shape (B, max_action_tokens) + - mask: Boolean mask with shape (B, max_action_tokens), True for real tokens, False for padding + """ + if action is None: + raise ValueError("Action cannot be None") + + # Get the device and dtype of the input action + device = action.device if isinstance(action, torch.Tensor) else None + + # Handle single sample (add batch dimension) + single_sample = action.dim() == 1 + if single_sample: + action = action.unsqueeze(0) + + batch_size = action.shape[0] + + # Tokenize the action batch + # The fast tokenizer expects action data and returns token IDs + tokens_list = [] + masks_list = [] + + for i in range(batch_size): + # Tokenize single action (move to CPU first as tokenizer uses scipy which requires numpy) + action_cpu = action[i : i + 1].cpu() + tokens = self.action_tokenizer(action_cpu) + + # Convert to numpy array if it's a list + if isinstance(tokens, list) or not isinstance(tokens, torch.Tensor): + tokens = torch.tensor(tokens, dtype=torch.long, device=action.device) + else: + # Move tokens back to the same device as input action + tokens = tokens.to(device=action.device) + + # Flatten to 1D if needed + if tokens.dim() > 1: + tokens = tokens.flatten() + + bos_id = self._paligemma_tokenizer.bos_token_id + # add bos + tokens = torch.cat( + [ + torch.tensor([bos_id], device=action.device), + torch.tensor( + self._paligemma_tokenizer.encode("Action: ", add_special_tokens=False), + device=action.device, + ), + self._act_tokens_to_paligemma_tokens(tokens), + torch.tensor(self._paligemma_tokenizer.encode("|"), device=action.device), + ] + ) + + # Truncate or pad to max_action_tokens + if len(tokens) > self.max_action_tokens: + logging.warning( + f"Token length ({len(tokens)}) exceeds max length ({self.max_action_tokens}), truncating. " + "Consider increasing the `max_action_tokens` in your model config if this happens frequently." + ) + tokens = tokens[: self.max_action_tokens] + mask = torch.ones(self.max_action_tokens, dtype=torch.bool, device=action.device) + else: + mask = torch.cat( + [ + torch.ones(len(tokens), dtype=torch.bool, device=action.device), + torch.zeros( + self.max_action_tokens - len(tokens), dtype=torch.bool, device=action.device + ), + ] + ) + # Pad tokens with zeros + tokens = torch.nn.functional.pad(tokens, (0, self.max_action_tokens - len(tokens)), value=0) + + tokens_list.append(tokens) + masks_list.append(mask) + + # Stack into batched tensors + tokens_batch = torch.stack(tokens_list, dim=0) # (B, max_action_tokens) + masks_batch = torch.stack(masks_list, dim=0) # (B, max_action_tokens) + + # Remove batch dimension if input was single sample + if single_sample: + tokens_batch = tokens_batch.squeeze(0) + masks_batch = masks_batch.squeeze(0) + + # Move to the same device as the input + if device is not None: + tokens_batch = tokens_batch.to(device) + masks_batch = masks_batch.to(device) + + return tokens_batch, masks_batch + + def action(self, action: torch.Tensor) -> torch.Tensor: + """ + This method is not used since we override __call__. + Required by ActionProcessorStep ABC. + """ + tokens, _ = self._tokenize_action(action) + return tokens + + def get_config(self) -> dict[str, Any]: + """ + Returns the serializable configuration of the processor. + + Note: The tokenizer object itself is not serialized. If the processor was initialized + with a tokenizer name, that name will be included in the config. 
+ + Returns: + A dictionary with the processor's configuration parameters. + """ + config = { + "trust_remote_code": self.trust_remote_code, + "max_action_tokens": self.max_action_tokens, + } + + # Only save tokenizer_name if it was used to create the tokenizer + if self.action_tokenizer_name is not None and self.action_tokenizer_input_object is None: + config["action_tokenizer_name"] = self.action_tokenizer_name + + return config + + def transform_features( + self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] + ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: + """ + Updates feature definitions to reflect tokenized actions. + + This updates the policy features dictionary to indicate that the action + has been tokenized into a sequence of token IDs with shape (max_action_tokens,). + + Args: + features: The dictionary of existing policy features. + + Returns: + The updated dictionary of policy features. + """ + return features diff --git a/src/lerobot/utils/constants.py b/src/lerobot/utils/constants.py index dfa10b2e5..a96e1596d 100644 --- a/src/lerobot/utils/constants.py +++ b/src/lerobot/utils/constants.py @@ -28,6 +28,8 @@ OBS_LANGUAGE_TOKENS = OBS_LANGUAGE + ".tokens" OBS_LANGUAGE_ATTENTION_MASK = OBS_LANGUAGE + ".attention_mask" ACTION = "action" +ACTION_TOKENS = ACTION + ".tokens" +ACTION_TOKEN_MASK = ACTION + ".token_mask" REWARD = "next.reward" TRUNCATED = "next.truncated" DONE = "next.done" diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 3a01aee88..e6817ba6c 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -63,6 +63,7 @@ def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[b _transformers_available = is_package_available("transformers") _peft_available = is_package_available("peft") +_scipy_available = is_package_available("scipy") def make_device_from_device_class(config: ChoiceRegistry) -> Any: diff --git a/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py new file mode 100644 index 000000000..9ebc4ba89 --- /dev/null +++ b/tests/policies/pi0_fast/test_pi0_fast_original_vs_lerobot.py @@ -0,0 +1,504 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
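A note on the token-ID remapping shared by `ActionTokenizerProcessorStep` and the policy's `_paligemma_tokens_to_act_tokens`: reflecting IDs around `vocab_size - 1 - fast_skip_tokens` is its own inverse, which is why encoding and decoding use the same formula. A self-contained sketch (the vocabulary size here is an assumption for illustration):

```python
# Hedged sketch of the self-inverse FAST <-> PaliGemma token-ID mapping.
VOCAB_SIZE = 257_152    # assumed PaliGemma tokenizer vocab size
FAST_SKIP_TOKENS = 128  # default used by the config in this patch

def remap(token_id: int) -> int:
    """Map a FAST token ID into the tail of the PaliGemma vocabulary (and back)."""
    return VOCAB_SIZE - 1 - FAST_SKIP_TOKENS - token_id

fast_id = 42
paligemma_id = remap(fast_id)          # 257_152 - 1 - 128 - 42 = 256_981
assert remap(paligemma_id) == fast_id  # applying the map twice is the identity
```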
+ +"""Test script to verify PI0Fast policy integration with LeRobot vs the original implementation""" +# ruff: noqa: E402 + +import os +import random +from copy import deepcopy +from typing import Any + +import numpy as np +import pytest +import torch + +pytest.importorskip("transformers") +pytest.importorskip("scipy") +pytestmark = pytest.mark.skipif( + os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true", + reason="This test requires accepting the model license", +) + +from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig +from lerobot.policies.pi0_fast.modeling_pi0_fast import PI0FastPolicy +from lerobot.policies.pi0_fast.processor_pi0_fast import make_pi0_fast_pre_post_processors +from lerobot.processor import PolicyAction, PolicyProcessorPipeline # noqa: E402 +from lerobot.utils.constants import ( + ACTION_TOKEN_MASK, + ACTION_TOKENS, + OBS_IMAGES, + OBS_LANGUAGE_ATTENTION_MASK, + OBS_LANGUAGE_TOKENS, + OBS_STATE, +) # noqa: E402 +from tests.utils import require_cuda # noqa: E402 + +# Constants +DUMMY_ACTION_DIM = 7 +DUMMY_STATE_DIM = 20 +IMAGE_HEIGHT = 224 +IMAGE_WIDTH = 224 +NUM_VIEWS = 2 # Number of camera views +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" +MODEL_PATH_LEROBOT = "lerobot/pi0fast-base" + +# Expected action token shape: (batch_size, max_decoding_steps) +EXPECTED_ACTION_TOKENS_SHAPE = (1, 2) + +# Expected first 5 action tokens (for reproducibility check) +EXPECTED_ACTION_TOKENS_FIRST_5 = torch.tensor([255657, 255362]) + +# Expected actions after detokenization +EXPECTED_ACTIONS_SHAPE = (1, 2, 32) # (batch_size, n_action_steps, action_dim) +EXPECTED_ACTIONS_MEAN = 0.04419417306780815 +EXPECTED_ACTIONS_STD = 0.26231569051742554 +EXPECTED_ACTIONS_FIRST_5 = torch.tensor([0.0000, 1.4849, 0.0000, 0.0000, 0.0000]) + + +def set_seed_all(seed: int): + """Set random seed for all RNG sources to ensure reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + # Set deterministic behavior + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True, warn_only=True) + + +def instantiate_lerobot_pi0_fast( + from_pretrained: bool = False, + model_path: str = MODEL_PATH_LEROBOT, +) -> tuple[ + Any, # Policy + PolicyProcessorPipeline[dict[str, Any], dict[str, Any]], + PolicyProcessorPipeline[PolicyAction, PolicyAction], +]: + """Instantiate LeRobot PI0Fast policy with preprocessor and postprocessor.""" + if from_pretrained: + policy = PI0FastPolicy.from_pretrained( + pretrained_name_or_path=model_path, + strict=True, + ) + policy.config.validate_action_token_prefix = False + policy.config.max_action_tokens = 2 + policy.config.max_decoding_steps = 2 + policy.config.chunk_size = 2 + policy.config.n_action_steps = 2 + else: + config = PI0FastConfig( + n_action_steps=2, + max_action_dim=DUMMY_ACTION_DIM, + max_state_dim=DUMMY_STATE_DIM, + device=DEVICE, + validate_action_token_prefix=False, + max_action_tokens=2, + max_decoding_steps=2, + chunk_size=2, + ) + policy = PI0FastPolicy(config) + + policy.to(DEVICE) + policy.config.device = DEVICE + preprocessor, postprocessor = make_pi0_fast_pre_post_processors( + config=policy.config, + dataset_stats=None, # Pass None for dataset_stats to disable normalization + ) + + return policy, preprocessor, postprocessor + + +def create_dummy_data(device=DEVICE): + """Create dummy data for testing both 
implementations.""" + batch_size = 1 + prompt = "Pick up the red block and place it in the bin" + + # Create random RGB images in [0, 255] uint8 range (as PIL images would be) + # Then convert to [0, 1] float32 range for LeRobot + def fake_rgb(h, w): + arr = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8) + t = torch.from_numpy(arr).permute(2, 0, 1) # CHW + return t + + batch = { + f"{OBS_IMAGES}.base_0_rgb": torch.stack( + [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)] + ).to(device), + f"{OBS_IMAGES}.left_wrist_0_rgb": torch.stack( + [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)] + ).to(device), + f"{OBS_IMAGES}.right_wrist_0_rgb": torch.stack( + [fake_rgb(IMAGE_HEIGHT, IMAGE_WIDTH) for _ in range(batch_size)] + ).to(device), + OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=device), + "task": [prompt for _ in range(batch_size)], + } + + return batch + + +# Pytest fixtures +@pytest.fixture(scope="module") +def pi0_fast_components(): + """Fixture to instantiate and provide all PI0Fast components for tests.""" + print(f"\nTesting with DEVICE='{DEVICE}'") + print("\n[Setup] Instantiating LeRobot PI0Fast policy...") + policy_obj, preprocessor_obj, postprocessor_obj = instantiate_lerobot_pi0_fast(from_pretrained=True) + print("Model loaded successfully") + yield policy_obj, preprocessor_obj, postprocessor_obj + + +@pytest.fixture(scope="module") +def policy(pi0_fast_components): + """Fixture to provide the PI0Fast policy for tests.""" + return pi0_fast_components[0] + + +@pytest.fixture(scope="module") +def preprocessor(pi0_fast_components): + """Fixture to provide the PI0Fast preprocessor for tests.""" + return pi0_fast_components[1] + + +@require_cuda +def test_pi0_fast_preprocessor_alignment(policy, preprocessor): + """Test that LeRobot PI0Fast preprocessor produces expected outputs.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Preprocessor Outputs") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + print("\n[LeRobot] Preprocessing...") + lerobot_observation = preprocessor(deepcopy(batch)) + + print("\nVerifying preprocessor outputs:") + print("-" * 80) + + # Expected keys from PI0Fast preprocessing + expected_keys = [ + "observation.images.base_0_rgb", + "observation.images.left_wrist_0_rgb", + "observation.images.right_wrist_0_rgb", + "observation.state", + "observation.language_tokens", + "observation.language_attention_mask", + ] + + for key in expected_keys: + if key in lerobot_observation: + shape = tuple(lerobot_observation[key].shape) + print(f"\nKey: {key}") + print(f"Shape: {shape}") + print(f"Dtype: {lerobot_observation[key].dtype}") + else: + print(f"\nKey '{key}' not found in inputs!") + + # Check language tokens shape + if "observation.language_tokens" in lerobot_observation: + lang_tokens = lerobot_observation["observation.language_tokens"] + print(f"\nLanguage tokens shape: {lang_tokens.shape}") + # Should have batch dimension and max_length from tokenizer + assert lang_tokens.dim() == 2, f"Expected 2D tensor, got {lang_tokens.dim()}D" + + print("\nPreprocessor outputs verified!") + + +@require_cuda +def test_pi0_fast_action_generation(policy, preprocessor): + """Test PI0Fast LeRobot implementation generates expected actions.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Action Generation Against Expected Values") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + 
print("\n[LeRobot] Running inference...") + lerobot_observation = preprocessor(deepcopy(batch)) + + # Reset seed for inference + torch.manual_seed(42) + with torch.no_grad(): + lerobot_actions = policy.predict_action_chunk(lerobot_observation) + lerobot_actions = lerobot_actions.float().cpu() + + print(f"LeRobot actions shape: {lerobot_actions.shape}") + print(f"LeRobot actions mean: {lerobot_actions.mean().item():.6f}") + print(f"LeRobot actions std: {lerobot_actions.std().item():.6f}") + print(f"LeRobot actions first 5: {lerobot_actions[0, 0, :5]}") + + print("\nExpected values (from original PI0Fast):") + print(f"Expected actions shape: {EXPECTED_ACTIONS_SHAPE}") + print(f"Expected actions mean: {EXPECTED_ACTIONS_MEAN:.6f}") + print(f"Expected actions std: {EXPECTED_ACTIONS_STD:.6f}") + print(f"Expected actions first 5: {EXPECTED_ACTIONS_FIRST_5}") + + print("\nAction Comparison:") + print("-" * 80) + + # Compare shapes + actual_shape = tuple(lerobot_actions.shape) + print(f"Actual shape: {actual_shape}") + + assert actual_shape == EXPECTED_ACTIONS_SHAPE, ( + f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTIONS_SHAPE}" + ) + print(f"Shape matches: {actual_shape}") + + # Compare statistics + actual_mean = lerobot_actions.mean().item() + actual_std = lerobot_actions.std().item() + + print(f"\nMean: {actual_mean:.6f} (expected: {EXPECTED_ACTIONS_MEAN:.6f})") + print(f"Std: {actual_std:.6f} (expected: {EXPECTED_ACTIONS_STD:.6f})") + + # Compare first 5 actions + actual_first_5 = lerobot_actions[0, 0, :5] + print("\nFirst 5 actions comparison:") + print(f" Actual: {actual_first_5}") + print(f" Expected: {EXPECTED_ACTIONS_FIRST_5}") + + first_5_diff = torch.abs(actual_first_5 - EXPECTED_ACTIONS_FIRST_5) + print(f" Max diff: {first_5_diff.max().item():.6e}") + print(f" Mean diff: {first_5_diff.mean().item():.6e}") + + # Check with different tolerances + tolerances = [1e-5, 1e-4, 1e-3, 1e-2] + for tol in tolerances: + is_close = torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tol) + status = "Success" if is_close else "Failure" + print(f"{status}: First 5 actions close (atol={tol}): {is_close}") + + # Assert with reasonable tolerance + tolerance = 1e-3 + assert torch.allclose(actual_first_5, EXPECTED_ACTIONS_FIRST_5, atol=tolerance), ( + f"First 5 actions differ by more than tolerance ({tolerance})" + ) + print(f"\nSuccess: Actions match expected values within tolerance ({tolerance})!") + + print("\nAction generation test completed (values printed for reference)!") + + +@require_cuda +def test_pi0_fast_inference_reproducibility(policy, preprocessor): + """Test that PI0Fast inference is reproducible with the same seed.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Inference Reproducibility") + print("=" * 80) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + # First inference + print("\n[Run 1] Running inference...") + set_seed_all(42) + lerobot_observation = preprocessor(deepcopy(batch)) + with torch.no_grad(): + actions_1 = policy.predict_action_chunk(lerobot_observation) + actions_1 = actions_1.float().cpu() + + # Second inference with same seed + print("\n[Run 2] Running inference with same seed...") + set_seed_all(42) + lerobot_observation = preprocessor(deepcopy(batch)) + with torch.no_grad(): + actions_2 = policy.predict_action_chunk(lerobot_observation) + actions_2 = actions_2.float().cpu() + + print("\nComparing two runs:") + print("-" * 80) + if torch.allclose(actions_1, actions_2, atol=1e-8): + print("Inference is perfectly reproducible!") + 
else: + diff = torch.abs(actions_1 - actions_2) + print("Small differences detected:") + print(f" Max diff: {diff.max().item():.6e}") + print(f" Mean diff: {diff.mean().item():.6e}") + + assert torch.allclose(actions_1, actions_2, atol=1e-6), "Inference should be reproducible!" + + print("\nInference is reproducible!") + + +@require_cuda +def test_pi0_fast_forward_pass_logits(policy, preprocessor): + """Test PI0Fast forward pass and compare logits against expected values.""" + print("\n" + "=" * 80) + print("Test: PI0Fast Forward Pass Logits") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data with action tokens...") + batch = create_dummy_data() + + # Preprocess the batch + lerobot_observation = preprocessor(deepcopy(batch)) + + # For forward pass, we need action tokens + # Create dummy action tokens for testing + batch_size = 1 + max_action_tokens = policy.config.max_action_tokens + + # Create dummy action tokens (in practice, these come from the FAST tokenizer) + dummy_action_tokens = torch.randint( + 0, 1000, (batch_size, max_action_tokens), dtype=torch.long, device=DEVICE + ) + dummy_action_masks = torch.ones(batch_size, max_action_tokens, dtype=torch.bool, device=DEVICE) + + # Add action tokens to the observation + lerobot_observation[ACTION_TOKENS] = dummy_action_tokens + lerobot_observation[ACTION_TOKEN_MASK] = dummy_action_masks + + print("\n[LeRobot] Running forward pass...") + policy.train() + with torch.no_grad(): + loss, loss_dict = policy.forward(lerobot_observation) + + print(f"Loss: {loss.item():.6f}") + print(f"FAST Loss: {loss_dict['ce_loss']:.6f}") + + print("\nForward pass completed successfully!") + print(f"Loss value: {loss.item():.6f}") + + # The loss should be a positive value + assert loss.item() > 0, "Loss should be positive" + assert not torch.isnan(loss), "Loss should not be NaN" + assert not torch.isinf(loss), "Loss should not be infinite" + + print("\nForward pass test passed!") + + +@require_cuda +def test_pi0_fast_action_token_sampling(policy, preprocessor): + """Test PI0Fast action token sampling (autoregressive decoding).""" + print("\n" + "=" * 80) + print("Test: PI0Fast Action Token Sampling") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + print("\n[LeRobot] Preprocessing...") + lerobot_observation = preprocessor(deepcopy(batch)) + + # Prepare inputs for model + images, img_masks = policy._preprocess_images(lerobot_observation) + tokens = lerobot_observation[OBS_LANGUAGE_TOKENS] + masks = lerobot_observation[OBS_LANGUAGE_ATTENTION_MASK] + + print("\n[LeRobot] Sampling action tokens...") + torch.manual_seed(42) + with torch.no_grad(): + action_tokens = policy.model.sample_actions_fast( + images, + img_masks, + tokens, + masks, + max_decoding_steps=2, + temperature=0.0, # Greedy decoding for reproducibility + ) + + print(f"Action tokens shape: {action_tokens.shape}") + print(f"Action tokens first 10: {action_tokens[0, :10].tolist()}") + + print("\nExpected values (from original PI0Fast):") + print(f"Expected shape: {EXPECTED_ACTION_TOKENS_SHAPE}") + print(f"Expected first 5: {EXPECTED_ACTION_TOKENS_FIRST_5.tolist()}") + + # Verify shape + actual_shape = tuple(action_tokens.shape) + print(f"\nActual shape: {actual_shape}") + + assert actual_shape == EXPECTED_ACTION_TOKENS_SHAPE, ( + f"Shape mismatch: {actual_shape} vs {EXPECTED_ACTION_TOKENS_SHAPE}" + ) + + # Compare first 5 tokens + actual_first_5 = action_tokens[0, :5].cpu() + assert torch.equal(actual_first_5, 
EXPECTED_ACTION_TOKENS_FIRST_5), ( + f"First 5 tokens mismatch: {actual_first_5} vs {EXPECTED_ACTION_TOKENS_FIRST_5}" + ) + + print("\nAction token sampling test completed!") + + +@require_cuda +def test_pi0_fast_detokenization(policy, preprocessor): + """Test PI0Fast action detokenization (FAST decoding).""" + print("\n" + "=" * 80) + print("Test: PI0Fast Action Detokenization") + print("=" * 80) + + set_seed_all(42) + + print("\nCreating dummy data...") + batch = create_dummy_data() + + print("\n[LeRobot] Preprocessing...") + lerobot_observation = preprocessor(deepcopy(batch)) + + # Prepare inputs for model + images, img_masks = policy._preprocess_images(lerobot_observation) + tokens = lerobot_observation[OBS_LANGUAGE_TOKENS] + masks = lerobot_observation[OBS_LANGUAGE_ATTENTION_MASK] + + print("\n[LeRobot] Sampling action tokens...") + torch.manual_seed(42) + with torch.no_grad(): + action_tokens = policy.model.sample_actions_fast( + images, + img_masks, + tokens, + masks, + max_decoding_steps=2, + temperature=0.0, + ) + + print(f"Action tokens shape: {action_tokens.shape}") + + # Detokenize + print("\n[LeRobot] Detokenizing action tokens...") + action_horizon = policy.config.n_action_steps + action_dim = policy.config.output_features["action"].shape[0] + + try: + continuous_actions = policy.detokenize_actions( + action_tokens, action_horizon=action_horizon, action_dim=action_dim + ) + print(f"Continuous actions shape: {continuous_actions.shape}") + print(f"Continuous actions mean: {continuous_actions.mean().item():.6f}") + print(f"Continuous actions std: {continuous_actions.std().item():.6f}") + print(f"Continuous actions first 5: {continuous_actions[0, 0, :5]}") + print("\nDetokenization successful!") + except Exception as e: + print(f"\nDetokenization failed with error: {e}") + print("This may be expected if the action tokens are not valid FAST tokens.") + print("The test will pass as long as the sampling works correctly.")
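Finally, for orientation on the tokenizer that these tests and the training script build on, a round-trip sketch with the pretrained universal FAST tokenizer (network access and `trust_remote_code=True` are assumed; the call signatures follow the `physical-intelligence/fast` model card):

```python
import numpy as np
from transformers import AutoProcessor

# Encode a normalized action chunk to BPE tokens, then decode it back.
tokenizer = AutoProcessor.from_pretrained("physical-intelligence/fast", trust_remote_code=True)

chunk = np.random.uniform(-1, 1, size=(1, 10, 7))  # (B, H, D), assumed normalized
tokens = tokenizer(chunk)                          # list of token-ID lists, one per chunk
recovered = tokenizer.decode(tokens, time_horizon=10, action_dim=7)
print(len(tokens[0]), recovered.shape)             # dense token count; (1, 10, 7)
```

Note that decoding is lossy (DCT coefficients are quantized), so `recovered` approximates `chunk` rather than matching it exactly.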