fix(profiling): preserve policy mode for deterministic forward

This commit is contained in:
Pepijn
2026-04-16 09:50:29 +02:00
parent 9dc38d9993
commit ed8a98dda6
2 changed files with 39 additions and 4 deletions
+36
View File
@@ -23,6 +23,8 @@ import subprocess
import sys
from pathlib import Path
import torch
def _import_model_profiling_script():
script_path = Path(__file__).resolve().parents[2] / "scripts" / "ci" / "run_model_profiling.py"
@@ -184,3 +186,37 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts
class _TrainingOnlyPolicy(torch.nn.Module):
def __init__(self):
super().__init__()
self.forward_calls = 0
def forward(self, batch):
self.forward_calls += 1
assert self.training
return batch["value"].sum(), {"value": batch["value"]}
dataset = [{"value": torch.tensor([1.0, 2.0])}]
policy = _TrainingOnlyPolicy()
policy.train()
write_deterministic_forward_artifacts(
policy=policy,
dataset=dataset,
batch_size=2,
preprocessor=lambda batch: batch,
output_dir=tmp_path,
device_type="cpu",
)
payload = json.loads((tmp_path / "deterministic_forward.json").read_text())
assert policy.training is True
assert policy.forward_calls == 1
assert payload["reference_batch_size"] == 2
assert "operator_fingerprint" in payload
assert payload["outputs"]["loss"]["numel"] == 1