From ed8a98dda65121fa7d1a827a5d1866bb10b19f7d Mon Sep 17 00:00:00 2001
From: Pepijn
Date: Thu, 16 Apr 2026 09:50:29 +0200
Subject: [PATCH] fix(profiling): preserve policy mode for deterministic forward

---
 src/lerobot/utils/profiling_utils.py  |  7 +++---
 tests/scripts/test_model_profiling.py | 36 +++++++++++++++++++++++++++
 2 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/src/lerobot/utils/profiling_utils.py b/src/lerobot/utils/profiling_utils.py
index 84738f24c..c3625363b 100644
--- a/src/lerobot/utils/profiling_utils.py
+++ b/src/lerobot/utils/profiling_utils.py
@@ -193,16 +193,15 @@ def write_deterministic_forward_artifacts(
     if device_type == "cuda":
         activities.append(torch.profiler.ProfilerActivity.CUDA)
 
-    was_training = policy.training
-    policy.eval()
+    # Keep the caller-selected module mode so the fingerprint matches the actual
+    # train-path forward used by the policy. Some policies, such as ACT with VAE,
+    # only materialize their full forward outputs while in training mode.
     with torch.random.fork_rng(devices=[] if device_type != "cuda" else None):
         torch.manual_seed(0)
         if device_type == "cuda":
             torch.cuda.manual_seed_all(0)
         with torch.no_grad(), torch.profiler.profile(activities=activities) as profiler:
             loss, output_dict = policy.forward(reference_batch)
-    if was_training:
-        policy.train()
 
     operator_entries = []
     for event in profiler.key_averages():
diff --git a/tests/scripts/test_model_profiling.py b/tests/scripts/test_model_profiling.py
index a9efb4ce5..61dae1232 100644
--- a/tests/scripts/test_model_profiling.py
+++ b/tests/scripts/test_model_profiling.py
@@ -23,6 +23,8 @@ import subprocess
 import sys
 from pathlib import Path
 
+import torch
+
 
 def _import_model_profiling_script():
     script_path = Path(__file__).resolve().parents[2] / "scripts" / "ci" / "run_model_profiling.py"
@@ -184,3 +186,37 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
     assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
     assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
     assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
+
+
+def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
+    from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts
+
+    class _TrainingOnlyPolicy(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.forward_calls = 0
+
+        def forward(self, batch):
+            self.forward_calls += 1
+            assert self.training
+            return batch["value"].sum(), {"value": batch["value"]}
+
+    dataset = [{"value": torch.tensor([1.0, 2.0])}]
+    policy = _TrainingOnlyPolicy()
+    policy.train()
+
+    write_deterministic_forward_artifacts(
+        policy=policy,
+        dataset=dataset,
+        batch_size=2,
+        preprocessor=lambda batch: batch,
+        output_dir=tmp_path,
+        device_type="cpu",
+    )
+
+    payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
+    assert policy.training is True
+    assert policy.forward_calls == 1
+    assert payload["reference_batch_size"] == 2
+    assert "operator_fingerprint" in payload
+    assert payload["outputs"]["loss"]["numel"] == 1