mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
b74a551d38
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchmark for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
292 lines
11 KiB
Python
292 lines
11 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Iterator
|
|
from contextlib import contextmanager
|
|
from dataclasses import dataclass
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import safetensors.torch
|
|
import torch
|
|
import torch.nn.functional as F # noqa: N812
|
|
from huggingface_hub import snapshot_download
|
|
from transformers import AutoTokenizer
|
|
|
|
from lerobot.utils.constants import (
|
|
ACTION,
|
|
OBS_LANGUAGE_ATTENTION_MASK,
|
|
OBS_LANGUAGE_TOKENS,
|
|
OBS_STATE,
|
|
)
|
|
from tests.policies.pi0_pi05.openpi_pytorch import preprocessing_pytorch as openpi_preprocessing
|
|
|
|
# Camera slots read from the raw batch under ``observation.images.<key>``; the tuple
# order is the order in which per-camera tensors are built and compared.
IMAGE_KEYS = (
    "base_0_rgb",
    "left_wrist_0_rgb",
    "right_wrist_0_rgb",
)

# Hugging Face Hub identifier of the tokenizer used to encode prompts.
TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
|
|
|
|
|
@dataclass
class OpenPIObservation:
    """Bundle of observation fields passed to the OpenPI-style reference preprocessing.

    The field order below defines the positional order of the generated ``__init__``,
    so it must stay stable.
    """

    # Robot state tensor.
    state: torch.Tensor
    # Per-camera image tensors, keyed by camera name.
    images: dict[str, torch.Tensor]
    # Per-camera validity masks, keyed like ``images``.
    image_masks: dict[str, torch.Tensor]
    # Token ids of the encoded prompt.
    tokenized_prompt: torch.Tensor
    # Mask associated with ``tokenized_prompt``.
    tokenized_prompt_mask: torch.Tensor
    # Auto-regressive mask over prompt tokens.
    token_ar_mask: torch.Tensor
    # Loss mask over prompt tokens.
    token_loss_mask: torch.Tensor
|
|
|
|
|
|
@lru_cache(maxsize=1)
def paligemma_tokenizer():
    """Return the tokenizer identified by ``TOKENIZER_NAME``.

    The result is memoized with ``lru_cache``, so ``AutoTokenizer.from_pretrained`` runs
    only on the first call and later calls reuse the same tokenizer object.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
    return tokenizer
|
|
|
|
|
|
def clone_batch(batch: dict) -> dict:
    """Return a copy of ``batch`` whose entries can be mutated independently.

    Args:
        batch: Mapping from keys to tensors, strings, or other iterables (e.g. a list
            of task prompts).

    Returns:
        A new dict with the same keys where
        * tensors are replaced by ``tensor.clone()``,
        * strings are passed through unchanged,
        * every other value is copied into a new ``list``.

    Raises:
        TypeError: If a value is neither a tensor, a string, nor iterable.
    """
    cloned = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            cloned[key] = value.clone()
        elif isinstance(value, str):
            # ``list("pick up")`` would split a single prompt into characters. Strings
            # are immutable, so sharing the same object is already a safe copy.
            cloned[key] = value
        else:
            cloned[key] = list(value)
    return cloned
|
|
|
|
|
|
def pad_last_dim(tensor: torch.Tensor, target_dim: int) -> torch.Tensor:
    """Right-pad the last dimension of ``tensor`` up to ``target_dim`` using ``F.pad``.

    Raises:
        ValueError: If the last dimension is already larger than ``target_dim``.
    """
    missing = target_dim - tensor.shape[-1]
    if missing < 0:
        raise ValueError(f"Cannot pad last dimension {tensor.shape[-1]} down to {target_dim}")
    # (0, missing): nothing added on the left, ``missing`` entries appended on the right.
    return F.pad(tensor, (0, missing))
|
|
|
|
|
|
def mean_std_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
    """Standardize ``tensor`` with ``stats["mean"]`` and ``stats["std"]``.

    Both statistics are first moved to the device and dtype of ``tensor``. A small
    constant (1e-8) is added to the standard deviation before dividing.
    """
    target = {"device": tensor.device, "dtype": tensor.dtype}
    mean = stats["mean"].to(**target)
    std = stats["std"].to(**target)
    centered = tensor - mean
    return centered / (std + 1e-8)
|
|
|
|
|
|
def quantile_normalize(tensor: torch.Tensor, stats: dict[str, torch.Tensor]) -> torch.Tensor:
    """Linearly rescale ``tensor`` so that ``stats["q01"]`` maps to -1 and ``stats["q99"]`` to +1.

    Both quantiles are first moved to the device and dtype of ``tensor``. Where the two
    quantiles coincide, the divisor is replaced by 1e-8 to avoid dividing by zero.
    """
    low = stats["q01"].to(device=tensor.device, dtype=tensor.dtype)
    high = stats["q99"].to(device=tensor.device, dtype=tensor.dtype)
    degenerate = high == low
    span = torch.where(degenerate, torch.full_like(high, 1e-8), high - low)
    shifted = tensor - low
    return 2.0 * shifted / span - 1.0
|
|
|
|
|
|
def openpi_model_state_from_raw(
    batch: dict[str, torch.Tensor],
    *,
    action_dim: int,
    dataset_stats: dict[str, dict[str, torch.Tensor]],
    pi05: bool,
) -> torch.Tensor:
    """Build the model-ready state tensor from a raw batch.

    The ``OBS_STATE`` entry is cast to float32, normalized with quantile statistics when
    ``pi05`` is true (mean/std statistics otherwise), and its last dimension is padded
    to ``action_dim``.
    """
    raw_state = batch[OBS_STATE].to(dtype=torch.float32)
    normalize = quantile_normalize if pi05 else mean_std_normalize
    normalized_state = normalize(raw_state, dataset_stats[OBS_STATE])
    return pad_last_dim(normalized_state, action_dim)
|
|
|
|
|
|
def openpi_model_actions_from_raw(
    batch: dict[str, torch.Tensor],
    *,
    action_dim: int,
    dataset_stats: dict[str, dict[str, torch.Tensor]],
    pi05: bool,
) -> torch.Tensor:
    """Build the model-ready action tensor from a raw batch.

    The ``ACTION`` entry is cast to float32, normalized with quantile statistics when
    ``pi05`` is true (mean/std statistics otherwise), and its last dimension is padded
    to ``action_dim``.
    """
    raw_actions = batch[ACTION].to(dtype=torch.float32)
    normalize = quantile_normalize if pi05 else mean_std_normalize
    normalized_actions = normalize(raw_actions, dataset_stats[ACTION])
    return pad_last_dim(normalized_actions, action_dim)
|
|
|
|
|
|
def _tasks_from_raw(batch: dict, batch_size: int) -> list[str]:
|
|
tasks = batch.get("task")
|
|
if tasks is None:
|
|
raise ValueError("The parity batch must include a task prompt.")
|
|
if isinstance(tasks, str):
|
|
return [tasks] * batch_size
|
|
if len(tasks) == 1:
|
|
return [tasks[0]] * batch_size
|
|
if len(tasks) != batch_size:
|
|
raise ValueError(f"Expected {batch_size} task prompts, got {len(tasks)}")
|
|
return list(tasks)
|
|
|
|
|
|
def _format_pi0_prompts(tasks: list[str]) -> list[str]:
|
|
return [f"{task.strip().replace('_', ' ').replace(chr(10), ' ')}\n" for task in tasks]
|
|
|
|
|
|
def _format_pi05_prompts(tasks: list[str], normalized_state: torch.Tensor) -> list[str]:
|
|
state_np = normalized_state.detach().cpu().numpy()
|
|
discretized_states = np.digitize(state_np, bins=np.linspace(-1, 1, 256 + 1)[:-1]) - 1
|
|
prompts = []
|
|
for task, state in zip(tasks, discretized_states, strict=True):
|
|
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
|
state_str = " ".join(map(str, state))
|
|
prompts.append(f"Task: {cleaned_text}, State: {state_str};\nAction: ")
|
|
return prompts
|
|
|
|
|
|
def _tokenize_prompts(prompts: list[str], *, max_token_len: int, device: torch.device | str):
    """Tokenize ``prompts`` to fixed-length tensors and move them to ``device``.

    Prompts are right-padded to, and truncated at, ``max_token_len``.

    Returns:
        A ``(tokens, masks)`` pair: the ``input_ids`` tensor and the ``attention_mask``
        converted to ``torch.bool``.
    """
    tokenizer = paligemma_tokenizer()
    tokenizer_kwargs = {
        "padding": "max_length",
        "padding_side": "right",
        "truncation": True,
        "max_length": max_token_len,
        "return_tensors": "pt",
    }
    encoded = tokenizer(prompts, **tokenizer_kwargs)
    input_ids = encoded["input_ids"].to(device)
    attention = encoded["attention_mask"].to(device=device, dtype=torch.bool)
    return input_ids, attention
|
|
|
|
|
|
def make_openpi_observation_from_raw(
    batch: dict[str, torch.Tensor],
    *,
    action_dim: int,
    max_token_len: int,
    dataset_stats: dict[str, dict[str, torch.Tensor]],
    pi05: bool,
) -> OpenPIObservation:
    """Assemble an ``OpenPIObservation`` from a raw batch.

    Steps, in order:
      1. normalize and pad the state,
      2. build the prompts (the PI05 variant also embeds the discretized state) and tokenize them,
      3. rescale each camera image with ``x * 2 - 1`` and create an all-true mask per camera.

    Batch size and device are taken from ``batch[OBS_STATE]``.
    """
    raw_state = batch[OBS_STATE]
    batch_size = raw_state.shape[0]
    device = raw_state.device

    state = openpi_model_state_from_raw(
        batch,
        action_dim=action_dim,
        dataset_stats=dataset_stats,
        pi05=pi05,
    )

    tasks = _tasks_from_raw(batch, batch_size)
    if pi05:
        prompts = _format_pi05_prompts(tasks, state)
    else:
        prompts = _format_pi0_prompts(tasks)
    tokens, masks = _tokenize_prompts(prompts, max_token_len=max_token_len, device=device)

    images = {}
    for key in IMAGE_KEYS:
        frame = batch[f"observation.images.{key}"].to(device=device, dtype=torch.float32)
        images[key] = frame * 2.0 - 1.0

    image_masks = {}
    for key in IMAGE_KEYS:
        image_masks[key] = torch.ones(batch_size, dtype=torch.bool, device=device)

    # No auto-regressive positions (all zeros); every token is included in the loss mask.
    ar_mask = torch.zeros_like(tokens, dtype=torch.int32)
    loss_mask = torch.ones_like(masks, dtype=torch.bool)

    return OpenPIObservation(
        state=state,
        images=images,
        image_masks=image_masks,
        tokenized_prompt=tokens,
        tokenized_prompt_mask=masks,
        token_ar_mask=ar_mask,
        token_loss_mask=loss_mask,
    )
|
|
|
|
|
|
def assert_processor_inputs_match_lerobot(
    lerobot_policy,
    lerobot_batch: dict[str, torch.Tensor],
    openpi_observation: OpenPIObservation,
    *,
    compare_state: bool,
):
    """Assert that the OpenPI-style and LeRobot processed inputs are exactly equal.

    Compared in order: prompt token ids, prompt attention masks, per-camera images,
    per-camera image masks and, when ``compare_state`` is true, the prepared state.
    Every comparison uses zero tolerance.

    Raises:
        AssertionError: On the first mismatch found by ``torch.testing.assert_close``.
        ValueError: If the two sides yield a different number of images or masks.
    """

    def _assert_identical(reference, candidate):
        # rtol=0 / atol=0 turns assert_close into an exact-equality check.
        torch.testing.assert_close(reference, candidate, rtol=0, atol=0)

    openpi_processed = openpi_preprocessing.preprocess_observation_pytorch(openpi_observation, train=False)
    lerobot_images, lerobot_image_masks = lerobot_policy._preprocess_images(lerobot_batch)

    # Both sides are derived from the same raw batch through independent
    # LeRobot/OpenPI-style processor logic, so they are required to agree bit for bit.
    _assert_identical(openpi_observation.tokenized_prompt, lerobot_batch[OBS_LANGUAGE_TOKENS])
    _assert_identical(
        openpi_observation.tokenized_prompt_mask,
        lerobot_batch[OBS_LANGUAGE_ATTENTION_MASK],
    )

    image_pairs = zip(openpi_processed.images.values(), lerobot_images, strict=True)
    for openpi_image, lerobot_image in image_pairs:
        _assert_identical(openpi_image, lerobot_image)

    mask_pairs = zip(openpi_processed.image_masks.values(), lerobot_image_masks, strict=True)
    for openpi_mask, lerobot_mask in mask_pairs:
        _assert_identical(openpi_mask, lerobot_mask)

    if not compare_state:
        return
    _assert_identical(openpi_processed.state, lerobot_policy.prepare_state(lerobot_batch))
|
|
|
|
|
|
def load_openpi_reference_state_dict(repo_id: str) -> dict[str, torch.Tensor]:
    """Fetch the model snapshot for ``repo_id`` and load its ``model.safetensors`` file.

    The snapshot location is obtained from ``snapshot_download`` with
    ``repo_type="model"``.
    """
    snapshot_root = snapshot_download(repo_id=repo_id, repo_type="model")
    weights_path = Path(snapshot_root) / "model.safetensors"
    return safetensors.torch.load_file(weights_path)
|
|
|
|
|
|
def fix_reference_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Return a shallow copy of ``state_dict`` with missing embedding weights filled in.

    If the LM head weight is present but the language-model ``embed_tokens`` weight is
    not, the latter is added as a clone of the former. The input mapping is not modified.
    """
    lm_head_key = "paligemma_with_expert.paligemma.lm_head.weight"
    embed_tokens_key = "paligemma_with_expert.paligemma.model.language_model.embed_tokens.weight"

    patched: dict[str, torch.Tensor] = {}
    patched.update(state_dict)

    has_lm_head = lm_head_key in patched
    has_embed_tokens = embed_tokens_key in patched
    if has_lm_head and not has_embed_tokens:
        patched[embed_tokens_key] = patched[lm_head_key].clone()
    return patched
|
|
|
|
|
|
@contextmanager
def fixed_flow_sampling(model, *, noise: torch.Tensor, time: torch.Tensor) -> Iterator[None]:
    """Temporarily make ``model.sample_noise`` / ``model.sample_time`` return fixed tensors.

    Inside the ``with`` block, the two attributes are replaced by stand-ins that check
    the requested shape / batch size and return ``noise`` / ``time`` moved to the
    requested device. The original attributes are put back on exit, even if the block
    raises.

    Raises (from the stand-ins):
        ValueError: If the requested noise shape or time batch size does not match the
            supplied tensors.
    """
    saved_noise_fn = model.sample_noise
    saved_time_fn = model.sample_time

    def sample_noise(shape, device):
        if tuple(shape) == tuple(noise.shape):
            return noise.to(device=device)
        raise ValueError(f"Expected noise shape {tuple(noise.shape)}, got {tuple(shape)}")

    def sample_time(batch_size, device):
        if batch_size == time.shape[0]:
            return time.to(device=device)
        raise ValueError(f"Expected time batch size {time.shape[0]}, got {batch_size}")

    model.sample_noise = sample_noise
    model.sample_time = sample_time
    try:
        yield
    finally:
        # Always undo the patch so it cannot leak past the ``with`` block.
        model.sample_noise = saved_noise_fn
        model.sample_time = saved_time_fn
|
|
|
|
|
|
@contextmanager
def deterministic_openpi_forward_preprocess(openpi_policy) -> Iterator[None]:
    """Force ``openpi_policy._preprocess_observation`` to run with ``train=False`` inside the block.

    Per the original author's note, OpenPI's ``forward()`` preprocesses observations with
    ``train=True``, which may apply stochastic image augmentation that LeRobot's forward
    path does not. Without this patch, a parity check would compare different image
    tensors instead of two model implementations. Callers keep using the public
    ``openpi_policy.forward(observation, ...)`` entry point unchanged.

    The code after ``yield`` sits in a ``finally`` clause, so the original method is
    restored even when the body of the ``with`` block raises, and the temporary
    monkeypatch cannot leak into later tests.
    """
    wrapped = openpi_policy._preprocess_observation

    def preprocess_observation(observation, *, train=True):
        # The caller's ``train`` value is deliberately ignored.
        return wrapped(observation, train=False)

    openpi_policy._preprocess_observation = preprocess_observation
    try:
        yield
    finally:
        openpi_policy._preprocess_observation = wrapped
|