Files
lerobot/tests/processor/test_pi05_processor.py
T
Haoming Song b74a551d38 fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash

* fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions

* fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete

* feat(test): add compile test and benchmark for pi0 and pi05

* feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
2026-05-22 10:29:34 +02:00

156 lines
5.1 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
import os
import pytest
import torch
pytest.importorskip("transformers")
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
from lerobot.policies.pi05 import PI05Policy # noqa: E402
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
IMAGE_KEYS,
assert_processor_inputs_match_lerobot,
clone_batch,
make_openpi_observation_from_raw,
openpi_model_actions_from_raw,
)
# True when either the generic CI flag or the GitHub Actions flag is set to "true".
_RUNS_ON_CI = any(os.environ.get(var) == "true" for var in ("CI", "GITHUB_ACTIONS"))

# Skip every test in this module on CI runners: the parity check uses the PaliGemma tokenizer.
pytestmark = pytest.mark.skipif(
    _RUNS_ON_CI,
    reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
)

# Sizes of the synthetic batch. The state/action dims are also used as the config's
# max_state_dim / max_action_dim in create_pi05_config.
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
DEVICE = torch.device("cpu")


def _dummy_stats(*shape: int) -> dict:
    """Return placeholder normalization stats of ``shape``: mean/q01 are zeros, std/q99 are ones."""
    return {
        "mean": torch.zeros(*shape),
        "std": torch.ones(*shape),
        "q01": torch.zeros(*shape),
        "q99": torch.ones(*shape),
    }


# Placeholder dataset statistics shared by the LeRobot preprocessor and the OpenPI reference.
# Per-camera image stats are nested under a single "images" key.
# NOTE(review): presumably this is the layout the openpi_parity helpers expect; confirm.
DUMMY_DATASET_STATS = {
    OBS_STATE: _dummy_stats(DUMMY_STATE_DIM),
    ACTION: _dummy_stats(DUMMY_ACTION_DIM),
    "images": {key: _dummy_stats(3, 224, 224) for key in IMAGE_KEYS},
}
class PI05PolicyInputAdapter(torch.nn.Module):
    """Lightweight stand-in for PI05Policy that only provides its image preprocessing.

    It borrows ``PI05Policy._preprocess_images`` as a plain class attribute, so the real
    image-preparation logic can run without building the model or loading weights.
    """

    _preprocess_images = PI05Policy._preprocess_images

    def __init__(self, config: PI05Config) -> None:
        super().__init__()
        self.config = config
        # A zero-dim, frozen parameter on the configured device.
        # NOTE(review): presumably present so that device lookups through the module's
        # parameters work inside the borrowed _preprocess_images; confirm.
        anchor = torch.empty((), device=config.device)
        self.register_parameter("_device_anchor", torch.nn.Parameter(anchor, requires_grad=False))
def create_pi05_config() -> PI05Config:
    """Build a CPU PI05Config sized to the dummy state/action/token dimensions of this module.

    The config declares one state input, one 3x224x224 visual input per entry in
    ``IMAGE_KEYS``, and a single action output.
    """
    cfg = PI05Config(device=str(DEVICE))
    # Overrides are applied in the same order as individual attribute assignments would be.
    overrides = (
        ("max_state_dim", DUMMY_STATE_DIM),
        ("max_action_dim", DUMMY_ACTION_DIM),
        ("chunk_size", DUMMY_ACTION_HORIZON),
        ("n_action_steps", DUMMY_ACTION_HORIZON),
        ("tokenizer_max_length", DUMMY_MAX_TOKEN_LEN),
    )
    for attr_name, attr_value in overrides:
        setattr(cfg, attr_name, attr_value)

    camera_features = {
        f"observation.images.{camera}": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224))
        for camera in IMAGE_KEYS
    }
    cfg.input_features = {
        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,)),
        **camera_features,
    }
    cfg.output_features = {
        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,)),
    }
    return cfg
def create_dummy_data() -> dict:
    """Return a random raw batch of two samples for the parity test.

    Contents:
    - state of shape (2, DUMMY_STATE_DIM), from ``torch.randn``;
    - actions of shape (2, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM), from ``torch.randn``;
    - one (2, 3, 224, 224) image per key in ``IMAGE_KEYS``, with values from ``torch.rand``;
    - the same task prompt repeated once per sample.

    All tensors are float32 on ``DEVICE``.
    """
    num_samples = 2
    instruction = "Pick up the red block and place it in the bin"
    tensor_opts = {"dtype": torch.float32, "device": DEVICE}

    # Random draws happen in this fixed order (state, action, then images in IMAGE_KEYS
    # order), so a seeded caller always gets the same batch.
    batch = {
        OBS_STATE: torch.randn(num_samples, DUMMY_STATE_DIM, **tensor_opts),
        ACTION: torch.randn(num_samples, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, **tensor_opts),
    }
    for camera in IMAGE_KEYS:
        batch[f"observation.images.{camera}"] = torch.rand(num_samples, 3, 224, 224, **tensor_opts)
    batch["task"] = [instruction] * num_samples
    return batch
def test_pi05_processor_inputs_match_openpi_reference():
    """Compare LeRobot's PI0.5 preprocessor output with the OpenPI reference.

    Two checks are made on the same seeded raw batch:
    - the image and token inputs, via ``assert_processor_inputs_match_lerobot``
      (the state comparison is disabled with ``compare_state=False``);
    - the processed action tensor, which must match ``openpi_model_actions_from_raw``
      exactly (``rtol=0, atol=0``).
    """
    torch.manual_seed(0)
    cfg = create_pi05_config()
    pre, _post = make_pi05_pre_post_processors(config=cfg, dataset_stats=DUMMY_DATASET_STATS)
    raw = create_dummy_data()
    # Run the preprocessor on a clone so the raw batch stays untouched for the reference path.
    processed = pre(clone_batch(raw))

    # Keyword arguments shared by both OpenPI reference helpers.
    reference_kwargs = {
        "action_dim": DUMMY_ACTION_DIM,
        "dataset_stats": DUMMY_DATASET_STATS,
        "pi05": True,
    }
    reference_obs = make_openpi_observation_from_raw(
        raw, max_token_len=DUMMY_MAX_TOKEN_LEN, **reference_kwargs
    )
    assert_processor_inputs_match_lerobot(
        PI05PolicyInputAdapter(cfg), processed, reference_obs, compare_state=False
    )

    processed_actions = processed[ACTION]
    expected_actions = openpi_model_actions_from_raw(raw, **reference_kwargs)
    torch.testing.assert_close(processed_actions, expected_actions, rtol=0, atol=0)