mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
b74a551d38
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchamrk for pi0 and pi05 * feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
156 lines
5.1 KiB
Python
156 lines
5.1 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Compare the PI0.5 processor pipeline against the vendored OpenPI reference processors."""
|
|
|
|
import os
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
pytest.importorskip("transformers")
|
|
|
|
from lerobot.configs import FeatureType, PolicyFeature # noqa: E402
|
|
from lerobot.policies.pi05 import PI05Policy # noqa: E402
|
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config # noqa: E402
|
|
from lerobot.policies.pi05.processor_pi05 import make_pi05_pre_post_processors # noqa: E402
|
|
from lerobot.utils.constants import ACTION, OBS_STATE # noqa: E402
|
|
from tests.policies.pi0_pi05.utils.openpi_parity import ( # noqa: E402
|
|
IMAGE_KEYS,
|
|
assert_processor_inputs_match_lerobot,
|
|
clone_batch,
|
|
make_openpi_observation_from_raw,
|
|
openpi_model_actions_from_raw,
|
|
)
|
|
|
|
# Module-wide skip marker: every test in this file is skipped when either the
# ``CI`` or the ``GITHUB_ACTIONS`` environment variable is exactly "true".
pytestmark = pytest.mark.skipif(
    any(os.environ.get(flag) == "true" for flag in ("CI", "GITHUB_ACTIONS")),
    reason="OpenPI processor parity uses the PaliGemma tokenizer; run manually outside CI.",
)
|
|
|
|
# Sizes of the synthetic inputs shared by the config and data builders in this module.
DUMMY_ACTION_DIM = 32
DUMMY_STATE_DIM = 32
DUMMY_ACTION_HORIZON = 50
DUMMY_MAX_TOKEN_LEN = 200
# The config and every dummy tensor in this module are created on CPU.
DEVICE = torch.device("cpu")


def _unit_range_stats(*shape: int) -> dict[str, torch.Tensor]:
    """Return placeholder stats of ``shape``: zero ``mean``/``q01`` and unit ``std``/``q99``."""
    return {
        "mean": torch.zeros(*shape),
        "std": torch.ones(*shape),
        "q01": torch.zeros(*shape),
        "q99": torch.ones(*shape),
    }


# Placeholder dataset statistics handed to both the LeRobot preprocessor and the
# OpenPI reference helpers. Image stats are nested under "images", one entry per
# camera key in ``IMAGE_KEYS``.
DUMMY_DATASET_STATS = {
    OBS_STATE: _unit_range_stats(DUMMY_STATE_DIM),
    ACTION: _unit_range_stats(DUMMY_ACTION_DIM),
    "images": {key: _unit_range_stats(3, 224, 224) for key in IMAGE_KEYS},
}
|
|
|
|
|
|
class PI05PolicyInputAdapter(torch.nn.Module):
    """Weight-free stand-in that exposes only the PI0.5 policy's image preparation.

    The adapter borrows ``PI05Policy._preprocess_images`` as its own method and
    stores the config that method is expected to read, so the test can prepare
    images without instantiating the full policy.
    """

    # Bound on adapter instances exactly as it would be on a real policy.
    _preprocess_images = PI05Policy._preprocess_images

    def __init__(self, config: PI05Config) -> None:
        """Store ``config`` and register a scalar parameter on ``config.device``."""
        super().__init__()
        self.config = config
        # Zero-dim, non-trainable parameter; presumably lets callers infer the
        # module's device from its parameters -- confirm against the helper.
        anchor = torch.empty((), device=config.device)
        self._device_anchor = torch.nn.Parameter(anchor, requires_grad=False)
|
|
|
|
|
|
def create_pi05_config() -> PI05Config:
    """Build a ``PI05Config`` on ``DEVICE`` sized to this module's dummy constants.

    Returns:
        A config whose state/action dimensions, chunk size, action steps and
        tokenizer length come from the ``DUMMY_*`` constants, with one state
        input, one visual input per key in ``IMAGE_KEYS`` and one action output.
    """
    image_shape = (3, 224, 224)
    config = PI05Config(device=str(DEVICE))

    config.max_state_dim = DUMMY_STATE_DIM
    config.max_action_dim = DUMMY_ACTION_DIM
    # The chunk size and the number of executed action steps share the same horizon.
    config.chunk_size = DUMMY_ACTION_HORIZON
    config.n_action_steps = DUMMY_ACTION_HORIZON
    config.tokenizer_max_length = DUMMY_MAX_TOKEN_LEN

    # State feature first, then one visual feature per camera key.
    input_features = {OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(DUMMY_STATE_DIM,))}
    for key in IMAGE_KEYS:
        input_features[f"observation.images.{key}"] = PolicyFeature(
            type=FeatureType.VISUAL, shape=image_shape
        )
    config.input_features = input_features

    config.output_features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(DUMMY_ACTION_DIM,))}
    return config
|
|
|
|
|
|
def create_dummy_data() -> dict:
    """Create a random raw batch of two samples on ``DEVICE``.

    Returns:
        A dict holding a float32 state tensor, a float32 action tensor, one
        float32 image tensor per key in ``IMAGE_KEYS`` (under
        ``observation.images.<key>``), and a ``"task"`` list that repeats the
        same prompt once per sample.
    """
    batch_size = 2
    prompt = "Pick up the red block and place it in the bin"

    # Tensors are drawn in a fixed order (state, actions, then each camera), so
    # a caller that seeds the global RNG beforehand gets reproducible values.
    batch = {
        OBS_STATE: torch.randn(batch_size, DUMMY_STATE_DIM, dtype=torch.float32, device=DEVICE),
        ACTION: torch.randn(
            batch_size, DUMMY_ACTION_HORIZON, DUMMY_ACTION_DIM, dtype=torch.float32, device=DEVICE
        ),
    }
    for key in IMAGE_KEYS:
        # ``torch.rand`` samples uniformly from [0, 1).
        batch[f"observation.images.{key}"] = torch.rand(
            batch_size, 3, 224, 224, dtype=torch.float32, device=DEVICE
        )
    batch["task"] = [prompt] * batch_size
    return batch
|
|
|
|
|
|
def test_pi05_processor_inputs_match_openpi_reference():
    """Check the LeRobot PI0.5 preprocessor against the OpenPI reference on a seeded dummy batch."""
    torch.manual_seed(0)
    config = create_pi05_config()
    preprocessor, _postprocessor = make_pi05_pre_post_processors(
        config=config, dataset_stats=DUMMY_DATASET_STATS
    )

    raw_batch = create_dummy_data()
    # The preprocessor receives a clone so ``raw_batch`` can be reused for the reference path.
    lerobot_batch = preprocessor(clone_batch(raw_batch))

    # Keyword arguments shared by both OpenPI reference helpers.
    openpi_kwargs = {
        "action_dim": DUMMY_ACTION_DIM,
        "dataset_stats": DUMMY_DATASET_STATS,
        "pi05": True,
    }
    openpi_observation = make_openpi_observation_from_raw(
        raw_batch, max_token_len=DUMMY_MAX_TOKEN_LEN, **openpi_kwargs
    )

    # State comparison is disabled here; presumably because PI0.5 carries the
    # state inside the tokenized prompt -- confirm against the parity helper.
    adapter = PI05PolicyInputAdapter(config)
    assert_processor_inputs_match_lerobot(adapter, lerobot_batch, openpi_observation, compare_state=False)

    # Zero tolerances: the processed actions must equal the reference exactly.
    actual_actions = lerobot_batch[ACTION]
    expected_actions = openpi_model_actions_from_raw(raw_batch, **openpi_kwargs)
    torch.testing.assert_close(actual_actions, expected_actions, rtol=0, atol=0)
|