mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
fix(profiling): preserve policy mode for deterministic forward
This commit is contained in:
@@ -193,16 +193,15 @@ def write_deterministic_forward_artifacts(
|
|||||||
if device_type == "cuda":
|
if device_type == "cuda":
|
||||||
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
activities.append(torch.profiler.ProfilerActivity.CUDA)
|
||||||
|
|
||||||
was_training = policy.training
|
# Keep the caller-selected module mode so the fingerprint matches the actual
|
||||||
policy.eval()
|
# train-path forward used by the policy. Some policies, such as ACT with VAE,
|
||||||
|
# only materialize their full forward outputs while in training mode.
|
||||||
with torch.random.fork_rng(devices=[] if device_type != "cuda" else None):
|
with torch.random.fork_rng(devices=[] if device_type != "cuda" else None):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
if device_type == "cuda":
|
if device_type == "cuda":
|
||||||
torch.cuda.manual_seed_all(0)
|
torch.cuda.manual_seed_all(0)
|
||||||
with torch.no_grad(), torch.profiler.profile(activities=activities) as profiler:
|
with torch.no_grad(), torch.profiler.profile(activities=activities) as profiler:
|
||||||
loss, output_dict = policy.forward(reference_batch)
|
loss, output_dict = policy.forward(reference_batch)
|
||||||
if was_training:
|
|
||||||
policy.train()
|
|
||||||
|
|
||||||
operator_entries = []
|
operator_entries = []
|
||||||
for event in profiler.key_averages():
|
for event in profiler.key_averages():
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import subprocess
|
|||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _import_model_profiling_script():
|
def _import_model_profiling_script():
|
||||||
script_path = Path(__file__).resolve().parents[2] / "scripts" / "ci" / "run_model_profiling.py"
|
script_path = Path(__file__).resolve().parents[2] / "scripts" / "ci" / "run_model_profiling.py"
|
||||||
@@ -184,3 +186,37 @@ def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
|
|||||||
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
|
assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
|
||||||
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
|
assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
|
||||||
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_deterministic_forward_artifacts_preserve_policy_mode(tmp_path):
    """Verify the deterministic-forward helper runs a training-mode policy as-is.

    A stub policy is put into training mode and its ``forward`` asserts that
    the module is still in training mode when called. After the helper runs,
    the test checks that the policy is still in training mode, that exactly
    one forward pass happened, and that ``deterministic_forward.json`` in
    ``tmp_path`` carries the expected fields.
    """
    from lerobot.utils.profiling_utils import write_deterministic_forward_artifacts

    class _TrainingOnlyPolicy(torch.nn.Module):
        """Stub policy whose forward pass fails unless the module is in training mode."""

        def __init__(self):
            super().__init__()
            # Counts forward invocations so the test can assert a single pass.
            self.forward_calls = 0

        def forward(self, batch):
            self.forward_calls += 1
            # Fails the test if the caller switched the module out of training mode.
            assert self.training
            values = batch["value"]
            return values.sum(), {"value": values}

    def _identity(batch):
        # Pass-through preprocessor: the batch is used unchanged.
        return batch

    policy = _TrainingOnlyPolicy()
    policy.train()
    dataset = [{"value": torch.tensor([1.0, 2.0])}]

    write_deterministic_forward_artifacts(
        policy=policy,
        dataset=dataset,
        batch_size=2,
        preprocessor=_identity,
        output_dir=tmp_path,
        device_type="cpu",
    )

    artifact_path = tmp_path / "deterministic_forward.json"
    payload = json.loads(artifact_path.read_text())

    assert policy.training is True
    assert policy.forward_calls == 1
    assert payload["reference_batch_size"] == 2
    assert "operator_fingerprint" in payload
    assert payload["outputs"]["loss"]["numel"] == 1
|
||||||
|
|||||||
Reference in New Issue
Block a user