Files
lerobot/tests/policies/pi0_pi05/utils/torch_compile.py
T
Haoming Song b74a551d38 fix(pi0, pi05): stabilize torch.compile and expand test coverage (#3610)
* chore(gr00t): sync with #3606 for fixing gr00t config crash

* fix(pi0&pi05): fix graph break caused by deepcopy of past_key_values in sample_actions

* fix(pi0&pi05): fix frequent recompile caused by compute_layer_complete

* feat(test): add compile test and benchmark for pi0 and pi05

* feat(test): add comprehensive testing for pi0 and pi05. Including processor, forward, sample action, etc.
2026-05-22 10:29:34 +02:00

208 lines
8.0 KiB
Python

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections.abc import Callable
import torch
from torch._dynamo.utils import counters, guard_failures
from torch.profiler import ProfilerActivity
# Tolerances handed to torch.testing.assert_close when comparing forward outputs
# (also reused for the repeated-call comparison in assert_cache_stability).
FORWARD_RTOL = 1e-5
FORWARD_ATOL = 5e-2
# Tolerances handed to torch.testing.assert_close when comparing sample_actions outputs.
SAMPLE_RTOL = 1e-5
SAMPLE_ATOL = 1e-2
# Value passed as `compile_mode` to the config built by make_compile_config.
COMPILE_MODE = "max-autotune"
# Number of untimed calls issued by run_warmups.
STEADY_STATE_WARMUPS = 3
# Number of timed calls profile_callable averages its timings over.
STEADY_STATE_REPEATS = 3
def make_compile_config(config_cls, *, compile_model):
    """Build a config targeting CUDA with the module-wide compile mode.

    Args:
        config_cls: Callable (typically a config class) that accepts the
            keyword arguments ``device``, ``compile_model`` and ``compile_mode``.
        compile_model: Forwarded unchanged as the ``compile_model`` keyword.

    Returns:
        Whatever ``config_cls`` returns for those keyword arguments.
    """
    settings = {
        "device": "cuda",
        "compile_model": compile_model,
        "compile_mode": COMPILE_MODE,
    }
    return config_cls(**settings)
def counter_total(name):
    """Return the sum of all values recorded under ``counters[name]``.

    The entry is looked up with ``.get`` rather than by indexing, so a name
    with no entry simply yields ``0``.
    """
    bucket = counters.get(name)
    if bucket is None:
        return 0
    return sum(bucket.values())
def compile_snapshot():
    """Capture the current Dynamo counter totals as a plain dict.

    Returns:
        A dict with the keys ``graph_breaks``, ``recompiles``,
        ``recompile_limits`` (each a ``counter_total`` of the matching counter
        name) and ``unique_graphs`` (read from ``counters["stats"]``,
        defaulting to 0 when absent).
    """
    # Snapshot key -> name of the counter whose values are summed.
    summed_counters = {
        "graph_breaks": "graph_break",
        "recompiles": "recompiles",
        "recompile_limits": "recompile_limit",
    }
    snapshot = {key: counter_total(source) for key, source in summed_counters.items()}
    # NOTE(review): "stats" is indexed directly, which assumes the entry exists
    # or that `counters` supplies a default for missing keys -- confirm.
    snapshot["unique_graphs"] = counters["stats"].get("unique_graphs", 0)
    return snapshot
def reset_compile_state():
    """Reset Dynamo and empty the module-level bookkeeping containers.

    Calls ``torch._dynamo.reset()`` first, then clears ``counters`` and
    ``guard_failures`` so later snapshots only reflect work done afterwards.
    """
    torch._dynamo.reset()
    for registry in (counters, guard_failures):
        registry.clear()
def clone_cuda_graph_output(output):
    """Recursively copy every tensor found inside ``output``.

    Tensors are replaced by ``.clone()`` copies. Tuples, lists and dicts are
    rebuilt as plain ``tuple`` / ``list`` / ``dict`` with their items processed
    the same way. Any other value is returned as-is.
    """
    if torch.is_tensor(output):
        return output.clone()
    if isinstance(output, dict):
        return {name: clone_cuda_graph_output(entry) for name, entry in output.items()}
    if isinstance(output, (tuple, list)):
        copied = [clone_cuda_graph_output(entry) for entry in output]
        # Keep the original sequence flavour (tuple stays tuple, list stays list).
        return tuple(copied) if isinstance(output, tuple) else copied
    return output
def run_model_step(fn: Callable, kwargs: dict):
if hasattr(torch.compiler, "cudagraph_mark_step_begin"):
torch.compiler.cudagraph_mark_step_begin()
return fn(**kwargs)
def assert_explain_has_no_graph_breaks(fn: Callable, kwargs: dict, label: str):
    """Run ``torch._dynamo.explain`` on ``fn`` and require a break-free capture.

    Args:
        fn: Callable to explain.
        kwargs: Keyword arguments used for the explained call.
        label: Prefix used in assertion messages and in the printed summary.

    Returns:
        The explanation object produced by ``torch._dynamo.explain``.

    Raises:
        AssertionError: If no graph was captured, the graph-break count is
            non-zero, or any break reasons were reported.
    """
    # Start from a clean compile state before explaining.
    reset_compile_state()
    explain_fn = torch._dynamo.explain(fn)
    explanation = explain_fn(**kwargs)

    assert explanation.graph_count > 0, f"{label} was not captured by Dynamo"

    break_count = explanation.graph_break_count
    reasons = explanation.break_reasons
    assert break_count == 0, f"{label} has {break_count} graph break(s): {reasons}"
    assert not reasons, f"{label} graph break reasons: {reasons}"

    # `out_guards` may be falsy; count it as zero guards in that case.
    guard_count = len(explanation.out_guards or [])
    summary = (
        f"{label} capture: graphs={explanation.graph_count}, "
        f"graph_breaks={break_count}, ops={explanation.op_count}, "
        f"guards={guard_count}"
    )
    print(summary)
    return explanation
@torch.no_grad()
def assert_compiled_output_matches_eager(eager_model, compiled_model, forward_kwargs, sample_kwargs):
    """Check that the compiled model reproduces the eager model's outputs.

    ``forward`` is compared first, then ``sample_actions``; each pair of
    results goes through ``torch.testing.assert_close`` with the tolerances
    defined for that method at module level.

    Args:
        eager_model: Reference model exposing ``forward`` and ``sample_actions``.
        compiled_model: Model under test exposing the same two methods.
        forward_kwargs: Keyword arguments for both ``forward`` calls.
        sample_kwargs: Keyword arguments for both ``sample_actions`` calls.

    Raises:
        AssertionError: If either comparison is outside its tolerances.
    """
    comparisons = (
        ("forward", forward_kwargs, FORWARD_RTOL, FORWARD_ATOL),
        ("sample_actions", sample_kwargs, SAMPLE_RTOL, SAMPLE_ATOL),
    )
    for method_name, call_kwargs, rtol, atol in comparisons:
        # Eager runs first so it serves as the expected value.
        expected = getattr(eager_model, method_name)(**call_kwargs)
        actual = getattr(compiled_model, method_name)(**call_kwargs)
        torch.testing.assert_close(actual, expected, rtol=rtol, atol=atol)
@torch.no_grad()
def assert_cache_stability(fn: Callable, kwargs: dict, label: str):
    """Call ``fn`` three times and require that only the first call compiles.

    After resetting compile state, each call's output is cloned and a counter
    snapshot is taken right after it. The second and third outputs must match
    the first within the forward tolerances. The snapshots must show no graph
    breaks, no recompiles, no recompile-limit hits and an unchanged
    ``unique_graphs`` count, and ``guard_failures`` must be empty.

    Args:
        fn: Callable under test.
        kwargs: Keyword arguments used for every call.
        label: Prefix used in assertion messages and in the printed summary.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    reset_compile_state()

    outputs = []
    snapshots = []
    for _ in range(3):
        # Clone right away so a later call cannot overwrite this result.
        outputs.append(clone_cuda_graph_output(run_model_step(fn, kwargs)))
        snapshots.append(compile_snapshot())
    first_output, second_output, third_output = outputs
    first_snapshot, second_snapshot, third_snapshot = snapshots

    for later_output in (second_output, third_output):
        torch.testing.assert_close(later_output, first_output, rtol=FORWARD_RTOL, atol=FORWARD_ATOL)

    baseline_graphs = first_snapshot["unique_graphs"]
    assert baseline_graphs > 0, f"{label} did not compile any graph"
    assert third_snapshot["graph_breaks"] == 0, f"{label} graph breaks: {third_snapshot}"
    assert third_snapshot["recompiles"] == 0, f"{label} recompiled: {third_snapshot}"
    assert third_snapshot["recompile_limits"] == 0, f"{label} hit recompile limit: {third_snapshot}"
    assert second_snapshot["unique_graphs"] == baseline_graphs, (
        f"{label} compiled new graph on second call: first={first_snapshot}, second={second_snapshot}"
    )
    assert third_snapshot["unique_graphs"] == baseline_graphs, (
        f"{label} compiled new graph on third call: first={first_snapshot}, third={third_snapshot}"
    )
    assert not guard_failures, f"{label} guard failures: {dict(guard_failures)}"
    print(f"{label} cache: first={first_snapshot}, third={third_snapshot}")
@torch.no_grad()
def benchmark_runtime(eager_fn: Callable, compiled_fn: Callable, kwargs: dict, label: str):
    """Profile an eager and a compiled callable and print a comparison line.

    Both callables are warmed up with ``run_warmups`` and then measured with
    ``profile_callable``. The printed line reports CUDA-event time, speedup,
    host wall time, CPU self time, CUDA launch counts, profiler event counts
    and peak memory for both.

    Args:
        eager_fn: Reference callable.
        compiled_fn: Compiled callable to compare against ``eager_fn``.
        kwargs: Keyword arguments used for every call of both callables.
        label: Prefix used in the printed report and in assertion messages.

    Returns:
        ``(eager_metrics, compiled_metrics)``, the two dicts returned by
        ``profile_callable``.

    Raises:
        AssertionError: If either CUDA-event time or either profiler event
            count is not strictly positive.
    """
    run_warmups(eager_fn, kwargs)
    run_warmups(compiled_fn, kwargs)
    torch.cuda.synchronize()
    eager_metrics = profile_callable(eager_fn, kwargs)
    compiled_metrics = profile_callable(compiled_fn, kwargs)

    eager_ms = eager_metrics["cuda_event_ms"]
    compiled_ms = compiled_metrics["cuda_event_ms"]
    # Guard the division: a zero compiled time must reach the sanity
    # assertions below (after the report is printed) instead of aborting
    # here with ZeroDivisionError.
    speedup = eager_ms / compiled_ms if compiled_ms > 0 else float("nan")
    print(
        f"{label} runtime: eager_cuda={eager_metrics['cuda_event_ms']:.3f} ms, "
        f"compiled_cuda={compiled_metrics['cuda_event_ms']:.3f} ms, speedup={speedup:.3f}x, "
        f"host_wall_ms eager/compiled={eager_metrics['host_wall_ms']:.3f}/"
        f"{compiled_metrics['host_wall_ms']:.3f}, "
        f"cpu_self_time_ms eager/compiled={eager_metrics['cpu_self_time_ms']:.3f}/"
        f"{compiled_metrics['cpu_self_time_ms']:.3f}, "
        f"cuda_launches eager/compiled={eager_metrics['cuda_launch_count']}/"
        f"{compiled_metrics['cuda_launch_count']}, "
        f"profiler_events eager/compiled={eager_metrics['profiler_event_count']}/"
        f"{compiled_metrics['profiler_event_count']}, "
        f"peak_mem_mib eager/compiled={eager_metrics['peak_mem_mib']:.1f}/"
        f"{compiled_metrics['peak_mem_mib']:.1f}"
    )
    assert eager_ms > 0, f"{label} eager CUDA time is not positive: {eager_metrics}"
    assert compiled_ms > 0, f"{label} compiled CUDA time is not positive: {compiled_metrics}"
    assert eager_metrics["profiler_event_count"] > 0, (
        f"{label} eager run produced no profiler events: {eager_metrics}"
    )
    assert compiled_metrics["profiler_event_count"] > 0, (
        f"{label} compiled run produced no profiler events: {compiled_metrics}"
    )
    return eager_metrics, compiled_metrics
def run_warmups(fn: Callable, kwargs: dict):
    """Call ``fn`` ``STEADY_STATE_WARMUPS`` times, then wait for the GPU.

    Each call goes through ``run_model_step``; results are discarded. A single
    ``torch.cuda.synchronize()`` is issued once all calls have been made.
    """
    warmup_rounds = STEADY_STATE_WARMUPS
    for _round in range(warmup_rounds):
        run_model_step(fn, kwargs)
    torch.cuda.synchronize()
def profile_callable(fn: Callable, kwargs: dict):
    """Measure timing, memory and profiler statistics for ``fn(**kwargs)``.

    Two passes are made. The timed pass runs ``fn`` ``STEADY_STATE_REPEATS``
    times between two CUDA events and a host ``perf_counter`` pair, after
    resetting the peak-memory statistics. The profiled pass runs ``fn`` once
    under ``torch.profiler.profile`` with the CPU activity enabled.

    Args:
        fn: Callable to measure; invoked through ``run_model_step``.
        kwargs: Keyword arguments used for every call.

    Returns:
        A dict with the keys ``cuda_event_ms`` and ``host_wall_ms`` (per-call
        averages over the timed pass), ``cpu_self_time_ms``,
        ``cuda_launch_count``, ``profiler_event_count`` (from the profiled
        pass) and ``peak_mem_mib`` (read after the timed pass).
    """
    launch_event_names = {"cudaLaunchKernel", "cudaGraphLaunch", "cudaLaunchKernelExC"}

    # Timed pass: drain pending work and reset the peak-memory counter first.
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    timing_start = torch.cuda.Event(enable_timing=True)
    timing_stop = torch.cuda.Event(enable_timing=True)
    wall_start = time.perf_counter()
    timing_start.record()
    for _ in range(STEADY_STATE_REPEATS):
        run_model_step(fn, kwargs)
    timing_stop.record()
    torch.cuda.synchronize()
    per_call_cuda_ms = timing_start.elapsed_time(timing_stop) / STEADY_STATE_REPEATS
    per_call_wall_ms = (time.perf_counter() - wall_start) * 1000 / STEADY_STATE_REPEATS
    peak_mib = torch.cuda.max_memory_allocated() / 1024**2

    # Profiled pass: one call, synchronised before the profiler context exits.
    with torch.profiler.profile(activities=[ProfilerActivity.CPU]) as profiler:
        run_model_step(fn, kwargs)
        torch.cuda.synchronize()

    # Fold the aggregated profiler entries in a single pass.
    self_cpu_total = 0
    launch_total = 0
    event_total = 0
    for entry in profiler.key_averages():
        self_cpu_total += entry.self_cpu_time_total
        event_total += entry.count
        if entry.key in launch_event_names:
            launch_total += entry.count

    metrics = {}
    metrics["cuda_event_ms"] = per_call_cuda_ms
    metrics["host_wall_ms"] = per_call_wall_ms
    metrics["cpu_self_time_ms"] = self_cpu_total / 1000
    metrics["cuda_launch_count"] = launch_total
    metrics["profiler_event_count"] = event_total
    metrics["peak_mem_mib"] = peak_mib
    return metrics