mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
b74a551d38
* chore(gr00t): sync with #3606 for fixing gr00t config crash * fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions * fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete * feat(test): add compile test and benchmark for pi0 and pi05 * feat(test): add comprehensive tests for pi0 and pi05, including processor, forward, sample action, etc.
102 lines
3.7 KiB
Python
102 lines
3.7 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import os
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
pytest.importorskip("transformers")
|
|
|
|
from lerobot.policies.pi05 import PI05Config # noqa: E402
|
|
from lerobot.policies.pi05.modeling_pi05 import PI05Pytorch # noqa: E402
|
|
from tests.policies.pi0_pi05.utils.torch_compile import ( # noqa: E402
|
|
assert_cache_stability,
|
|
assert_compiled_output_matches_eager,
|
|
assert_explain_has_no_graph_breaks,
|
|
benchmark_runtime,
|
|
make_compile_config,
|
|
reset_compile_state,
|
|
)
|
|
from tests.utils import require_cuda # noqa: E402
|
|
|
|
# Module-wide skip: these torch.compile benchmarks are too slow for CI, so every
# test here is skipped when either the generic ``CI`` flag or GitHub Actions'
# ``GITHUB_ACTIONS`` flag is set to the exact string "true".
pytestmark = pytest.mark.skipif(
    any(os.environ.get(env_var) == "true" for env_var in ("CI", "GITHUB_ACTIONS")),
    reason="torch.compile benchmark is too slow for CI; run manually on GPU nodes",
)
|
|
|
|
|
|
def _make_model(*, compile_model):
    """Build a ``PI05Pytorch`` from the shared compile-test config, on CUDA, in eval mode.

    Args:
        compile_model: Forwarded unchanged to ``make_compile_config`` as its
            ``compile_model`` keyword.

    Returns:
        The model after ``.cuda().eval()``.
    """
    config = make_compile_config(PI05Config, compile_model=compile_model)
    model = PI05Pytorch(config)
    return model.cuda().eval()
|
|
|
|
|
|
def _make_dummy_inputs(config):
    """Create random CUDA inputs for ``forward`` and ``sample_actions``.

    Both returned dicts reference the very same observation tensors (images,
    image masks, token ids, token mask). Each dict gets its own freshly drawn
    ``noise`` tensor.

    Args:
        config: Policy config. Only ``image_resolution``, ``chunk_size``,
            ``max_action_dim`` and ``num_inference_steps`` are read.

    Returns:
        A ``(forward_kwargs, sample_kwargs)`` tuple of keyword-argument dicts.
    """
    cuda = torch.device("cuda")
    action_shape = (1, config.chunk_size, config.max_action_dim)

    # NOTE: the order of the random draws below fixes which values each tensor
    # receives from the global RNG stream; keep it stable.
    image = torch.randn(1, 3, *config.image_resolution, device=cuda)
    image_mask = torch.ones(1, dtype=torch.bool, device=cuda)
    token_ids = torch.randint(0, 1024, (1, 5), dtype=torch.long, device=cuda)
    token_mask = torch.ones(1, 5, dtype=torch.bool, device=cuda)

    observation = {
        "images": [image],
        "img_masks": [image_mask],
        "tokens": token_ids,
        "masks": token_mask,
    }

    forward_kwargs = dict(
        observation,
        actions=torch.randn(*action_shape, device=cuda),
        noise=torch.randn(*action_shape, device=cuda),
        time=torch.rand(1, device=cuda),
    )
    sample_kwargs = dict(
        observation,
        noise=torch.randn(*action_shape, device=cuda),
        num_steps=config.num_inference_steps,
    )
    return forward_kwargs, sample_kwargs
|
|
|
|
|
|
@require_cuda
def test_pi05_torch_compile_forward_and_sample_actions():
    """Compare a torch.compile'd PI05Pytorch against an eager one, then benchmark both.

    The test builds an eager model and a compile-enabled model, creates dummy
    inputs, and then runs these helpers on ``forward`` and ``sample_actions``:

    1. ``assert_compiled_output_matches_eager`` on the two models.
    2. ``assert_explain_has_no_graph_breaks`` on the eager model's methods.
    3. ``assert_cache_stability`` on the compiled model's methods.
    4. ``benchmark_runtime`` comparing eager vs. compiled.

    Model construction and input creation happen inside the ``try``. That way
    ``reset_compile_state()`` and the CUDA cache cleanup still run when building
    the (compiled) model fails, not only when an assertion fails.
    """
    if not hasattr(torch, "compile"):
        pytest.skip("torch.compile is not available")
    if not torch._dynamo.is_dynamo_supported():
        pytest.skip("torch._dynamo is not supported on this platform")

    # Pre-bind so the ``finally`` clause can always ``del`` these names,
    # even if setup raises before they are assigned.
    eager_model = compiled_model = None
    try:
        # Re-seed before each build so both models draw their initialization
        # from the same RNG state.
        torch.manual_seed(0)
        eager_model = _make_model(compile_model=False)
        torch.manual_seed(0)
        compiled_model = _make_model(compile_model=True)
        forward_kwargs, sample_kwargs = _make_dummy_inputs(compiled_model.config)

        assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs)

        assert_explain_has_no_graph_breaks(eager_model.forward, forward_kwargs, "pi05.forward")
        assert_explain_has_no_graph_breaks(eager_model.sample_actions, sample_kwargs, "pi05.sample_actions")

        assert_cache_stability(compiled_model.forward, forward_kwargs, "pi05.forward")
        assert_cache_stability(compiled_model.sample_actions, sample_kwargs, "pi05.sample_actions")

        benchmark_runtime(eager_model.forward, compiled_model.forward, forward_kwargs, "pi05.forward")
        benchmark_runtime(
            eager_model.sample_actions,
            compiled_model.sample_actions,
            sample_kwargs,
            "pi05.sample_actions",
        )
    finally:
        reset_compile_state()
        # Drop this frame's references to the models before asking the CUDA
        # caching allocator to release unused memory.
        del eager_model, compiled_model
        torch.cuda.empty_cache()
|