feat(profiling): add weekly model profiling

This commit is contained in:
Pepijn
2026-04-15 22:31:44 +02:00
parent bd74f6733d
commit 1a2aec1b04
7 changed files with 1307 additions and 115 deletions
+161
View File
@@ -0,0 +1,161 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Profiles the training loop of the supported policies on a GPU runner.
#
# - schedule: weekly run; results are published to the profiling dataset.
# - pull_request: only when profiling-related files change; runs the `act`
#   policy in `summary` mode and never publishes.
# - workflow_dispatch: manual run with the inputs declared below.
name: Model Profiling

on:
  schedule:
    # Sundays at 00:00 (GitHub evaluates cron expressions in UTC).
    - cron: "0 0 * * 0"
  pull_request:
    branches:
      - main
      - feat/libero-benchmark
    paths:
      - .github/workflows/model_profiling.yml
      - profiling/model_profiling_specs.json
      - scripts/ci/run_model_profiling.py
      - src/lerobot/configs/train.py
      - src/lerobot/scripts/lerobot_train.py
      - src/lerobot/utils/profiling_utils.py
      - tests/scripts/test_model_profiling.py
  workflow_dispatch:
    inputs:
      git_ref:
        description: Git ref to profile when no commit SHA is provided
        required: false
        type: string
        default: main
      git_commit:
        description: Optional exact commit SHA to profile
        required: false
        type: string
        default: ""
      policies:
        description: Optional comma-separated policy filter
        required: false
        type: string
        default: ""
      profile_mode:
        description: Torch profiler mode
        required: false
        type: choice
        options:
          - trace
          - summary
        default: trace
      publish_results:
        description: Publish results to the profiling dataset when a Hub token is available
        required: false
        type: boolean
        default: true
      results_repo:
        description: Dataset repo name or fully qualified repo id
        required: false
        type: string
        default: model-profiling-history

permissions:
  contents: read

concurrency:
  group: ${{ github.workflow }}-${{ github.event_name }}-${{ github.event.inputs.git_commit || github.event.inputs.git_ref || github.ref_name || github.run_id }}
  cancel-in-progress: true

jobs:
  profile-models:
    name: Weekly Model Profiling
    runs-on:
      group: aws-g6-4xlarge-plus
    env:
      HF_USER_TOKEN: ${{ secrets.LEROBOT_HF_USER }}
      # Pull requests are pinned to a cheap configuration: summary mode, `act` only.
      PROFILE_MODE: ${{ github.event_name == 'pull_request' && 'summary' || github.event.inputs.profile_mode || 'trace' }}
      POLICY_FILTER: ${{ github.event_name == 'pull_request' && 'act' || github.event.inputs.policies || '' }}
      RESULTS_REPO: ${{ github.event.inputs.results_repo || 'model-profiling-history' }}
      SHOULD_PUBLISH: ${{ github.event_name == 'schedule' || (github.event_name == 'workflow_dispatch' && github.event.inputs.publish_results == 'true') }}
      # Resolved here rather than inside `run:` so user-controlled input is
      # never expanded into the shell script text (script injection).
      HOST_GIT_COMMIT: ${{ github.event.pull_request.head.sha || github.event.inputs.git_commit || github.sha }}
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false
          lfs: true
          ref: ${{ github.event.pull_request.head.sha || github.event.inputs.git_commit || github.event.inputs.git_ref || 'main' }}

      - name: Pull GPU image
        run: docker pull huggingface/lerobot-gpu:latest

      - name: Run model profiling
        run: |
          set -eux
          mkdir -p profiling-results
          # Every option line must end with a backslash; a missing one ends the
          # `docker run` command early and the remaining lines run on their own.
          docker run --rm --gpus all \
            --shm-size=16g \
            -e HF_HOME=/tmp/hf \
            -e UV_PROJECT_ENVIRONMENT=/tmp/lerobot-venv \
            -e UV_CACHE_DIR=/tmp/uv-cache \
            -e XDG_CACHE_HOME=/tmp/xdg-cache \
            -e HOST_GIT_COMMIT="${HOST_GIT_COMMIT}" \
            -e HF_USER_TOKEN="${HF_USER_TOKEN}" \
            -e HF_TOKEN="${HF_USER_TOKEN}" \
            -e PROFILE_MODE="${PROFILE_MODE}" \
            -e POLICY_FILTER="${POLICY_FILTER}" \
            -e RESULTS_REPO="${RESULTS_REPO}" \
            -e SHOULD_PUBLISH="${SHOULD_PUBLISH}" \
            -v "${GITHUB_WORKSPACE}:/workspace" \
            -w /workspace \
            huggingface/lerobot-gpu:latest \
            bash -lc '
              set -euxo pipefail
              rm -rf /tmp/lerobot-src
              cp -a /workspace/. /tmp/lerobot-src
              cd /tmp/lerobot-src
              if [[ -n "${HF_USER_TOKEN:-}" ]]; then
                hf auth login --token "${HF_USER_TOKEN}" --add-to-git-credential 2>/dev/null || true
              fi
              uv sync --locked --extra all
              cmd=(
                uv run python scripts/ci/run_model_profiling.py
                --output_dir=/workspace/profiling-results
                --hub_org=lerobot
                --results_repo="${RESULTS_REPO}"
                --profile_mode="${PROFILE_MODE}"
                --git_commit="${HOST_GIT_COMMIT}"
              )
              if [[ -n "${POLICY_FILTER}" ]]; then
                IFS="," read -ra policies <<< "${POLICY_FILTER}"
                cmd+=(--policies)
                for policy in "${policies[@]}"; do
                  policy="$(echo "${policy}" | xargs)"
                  if [[ -n "${policy}" ]]; then
                    cmd+=("${policy}")
                  fi
                done
              fi
              if [[ "${SHOULD_PUBLISH}" == "true" && -n "${HF_USER_TOKEN:-}" ]]; then
                cmd+=(--publish)
              fi
              "${cmd[@]}"
            '

      - name: Upload profiling artifacts
        if: always()
        uses: actions/upload-artifact@v4 # zizmor: ignore[unpinned-uses]
        with:
          name: model-profiling-results
          path: profiling-results
          if-no-files-found: warn
+128
View File
@@ -0,0 +1,128 @@
{
"act": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=act",
"--policy.device=cuda",
"--batch_size=4"
]
},
"diffusion": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=diffusion",
"--policy.device=cuda",
"--batch_size=4"
]
},
"groot": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/libero_plus",
"--dataset.episodes=[0]",
"--policy.type=groot",
"--policy.base_model_path=nvidia/GR00T-N1.5-3B",
"--policy.tune_diffusion_model=true",
"--policy.tune_projector=true",
"--policy.tune_llm=false",
"--policy.tune_visual=false",
"--policy.use_bf16=true",
"--policy.device=cuda",
"--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
]
},
"multi_task_dit": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=multi_task_dit",
"--policy.device=cuda",
"--policy.horizon=32",
"--policy.n_action_steps=30",
"--batch_size=4"
]
},
"pi0": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/libero_plus",
"--dataset.episodes=[0]",
"--policy.path=lerobot/pi0_base",
"--policy.device=cuda",
"--policy.n_action_steps=30",
"--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
]
},
"pi0_fast": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/libero_plus",
"--dataset.episodes=[0]",
"--policy.path=lerobot/pi0fast-base",
"--policy.device=cuda",
"--policy.n_action_steps=30",
"--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
]
},
"pi05": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/libero_plus",
"--dataset.episodes=[0]",
"--policy.path=lerobot/pi05_base",
"--policy.device=cuda",
"--policy.n_action_steps=30",
"--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
]
},
"smolvla": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/libero_plus",
"--dataset.episodes=[0]",
"--policy.path=lerobot/smolvla_base",
"--policy.load_vlm_weights=true",
"--policy.freeze_vision_encoder=false",
"--policy.train_expert_only=false",
"--policy.empty_cameras=1",
"--policy.device=cuda",
"--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
]
},
"wall_x": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/aloha_sim_insertion_human",
"--dataset.episodes=[0]",
"--policy.type=wall_x",
"--policy.pretrained_name_or_path=x-square-robot/wall-oss-flow",
"--policy.prediction_mode=diffusion",
"--policy.attn_implementation=eager",
"--policy.device=cuda",
"--batch_size=1"
]
},
"xvla": {
"steps": 12,
"train_args": [
"--dataset.repo_id=lerobot/libero_plus",
"--dataset.episodes=[0]",
"--policy.path=lerobot/xvla-widowx",
"--policy.action_mode=auto",
"--policy.empty_cameras=1",
"--policy.device=cuda",
"--batch_size=1",
"--rename_map={\"observation.images.image\": \"observation.images.camera1\", \"observation.images.image2\": \"observation.images.camera2\"}"
]
}
}
+290
View File
@@ -0,0 +1,290 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import argparse
import json
import subprocess
import time
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from huggingface_hub import HfApi
@dataclass(frozen=True)
class ProfilingSpec:
    """Immutable description of a single policy profiling run."""

    # Key of the entry in the specs JSON file.
    name: str
    # Step count forwarded to the trainer (`--steps` and `--save_freq`).
    steps: int
    # Extra command-line arguments appended verbatim to `lerobot-train`.
    train_args: list[str]
@dataclass(frozen=True)
class UploadTarget:
    """Immutable pairing of a local file with its destination inside a Hub repo."""

    # File on disk to upload.
    local_path: Path
    # Destination path relative to the repo root.
    path_in_repo: str
def utc_timestamp_slug(now: datetime | None = None) -> str:
current = now or datetime.now(UTC)
return current.strftime("%Y%m%dT%H%M%SZ")
def make_hub_file_url(repo_id: str, path_in_repo: str, repo_type: str = "dataset") -> str:
    """Build the ``resolve/main`` download URL of a file stored in a Hub repo.

    Args:
        repo_id: Fully qualified repo id (``org/name``).
        path_in_repo: File path relative to the repo root.
        repo_type: ``"dataset"`` and ``"space"`` repos live under a URL prefix;
            any other value (e.g. ``"model"``) uses no prefix.

    Returns:
        The URL pointing at the file on the ``main`` revision.
    """
    # Dataset and Space repos are namespaced in the URL; model repos are not.
    url_prefixes = {"dataset": "datasets/", "space": "spaces/"}
    prefix = url_prefixes.get(repo_type, "")
    return f"https://huggingface.co/{prefix}{repo_id}/resolve/main/{path_in_repo}"
def upload_targets(
    repo_id: str,
    targets: list[UploadTarget],
    *,
    repo_type: str = "dataset",
    token: str | None = None,
    private: bool | None = None,
    commit_message: str | None = None,
) -> dict[str, str]:
    """Upload files to a Hub repo, creating the repo first when it is missing.

    Args:
        repo_id: Destination repo id.
        targets: Files to upload, each with its destination path.
        repo_type: Hub repo type passed to every API call.
        token: Optional token handed to ``HfApi``.
        private: Visibility passed to ``create_repo``.
        commit_message: Message used for every upload; when empty, each file
            gets ``"Upload <path_in_repo>"``.

    Returns:
        Mapping of each uploaded ``path_in_repo`` to its file URL.
    """
    hub = HfApi(token=token)
    hub.create_repo(repo_id=repo_id, repo_type=repo_type, private=private, exist_ok=True)

    urls_by_path: dict[str, str] = {}
    for item in targets:
        destination = item.path_in_repo
        message = commit_message or f"Upload {destination}"
        # One upload_file call (hence one commit) per target.
        hub.upload_file(
            path_or_fileobj=str(item.local_path),
            path_in_repo=destination,
            repo_id=repo_id,
            repo_type=repo_type,
            commit_message=message,
        )
        urls_by_path[destination] = make_hub_file_url(repo_id, destination, repo_type=repo_type)
    return urls_by_path
def normalize_repo_id(repo: str, hub_org: str) -> str:
    """Return a fully qualified repo id, prefixing ``hub_org`` to bare names.

    A value that already contains ``/`` is returned unchanged.
    """
    if "/" in repo:
        return repo
    return f"{hub_org}/{repo}"
def load_specs(path: Path) -> dict[str, ProfilingSpec]:
    """Read the profiling specs JSON file into ``ProfilingSpec`` objects.

    Args:
        path: JSON file mapping a policy name to an object with ``steps`` and
            ``train_args`` entries.

    Returns:
        Specs keyed by policy name, in file order.
    """
    raw_specs = json.loads(path.read_text())
    specs: dict[str, ProfilingSpec] = {}
    for policy_name, entry in raw_specs.items():
        specs[policy_name] = ProfilingSpec(
            name=policy_name,
            steps=entry["steps"],
            train_args=entry["train_args"],
        )
    return specs
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the profiling runner.

    Returns:
        Namespace with ``spec_file``, ``policies``, ``output_dir``, ``hub_org``,
        ``results_repo``, ``publish``, ``profile_mode`` and ``git_commit``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # Every other flag uses underscores; accept `--spec_file` too while keeping
    # the original `--spec-file` spelling working.
    parser.add_argument(
        "--spec-file",
        "--spec_file",
        dest="spec_file",
        type=Path,
        default=Path("profiling/model_profiling_specs.json"),
    )
    # `nargs="*"`: a bare `--policies` yields an empty list.
    parser.add_argument("--policies", nargs="*", default=None)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument("--hub_org", default="lerobot")
    parser.add_argument("--results_repo", default="model-profiling-history")
    parser.add_argument("--publish", action="store_true")
    parser.add_argument("--profile_mode", choices=["summary", "trace"], default="trace")
    parser.add_argument("--git_commit", default="")
    return parser.parse_args()
def get_selected_names(requested: list[str] | None, specs: dict[str, ProfilingSpec]) -> list[str]:
    """Resolve which policies to profile.

    Args:
        requested: Policy names given by the caller; ``None`` or an empty list
            selects every spec.
        specs: Available specs keyed by policy name.

    Returns:
        All spec names when nothing was requested, otherwise ``requested``
        unchanged.

    Raises:
        ValueError: If any requested name is not a key of ``specs``.
    """
    if requested:
        missing = sorted(name for name in set(requested) if name not in specs)
        if missing:
            raise ValueError(f"Unknown profiling policies: {', '.join(missing)}")
        return requested
    return list(specs)
def build_train_command(spec: ProfilingSpec, run_dir: Path, profile_mode: str) -> list[str]:
    """Assemble the ``lerobot-train`` command line for one profiling run.

    Args:
        spec: Supplies the policy-specific arguments and the step count.
        run_dir: Run directory; training output goes to ``run_dir/train`` and
            profiling artifacts to ``run_dir/profiling``.
        profile_mode: Value passed as ``--profile_mode``.

    Returns:
        The argv list, starting with ``uv run lerobot-train``.
    """
    command = ["uv", "run", "lerobot-train"]
    command.extend(spec.train_args)
    command.append(f"--output_dir={run_dir / 'train'}")
    command.append(f"--steps={spec.steps}")
    # Fixed overrides applied to every profiling run.
    command.extend(
        [
            "--eval_freq=0",
            "--save_checkpoint=false",
            f"--save_freq={spec.steps}",
            "--wandb.enable=false",
            "--num_workers=0",
            "--log_freq=1",
            "--cudnn_deterministic=true",
        ]
    )
    command.append(f"--profile_mode={profile_mode}")
    command.append(f"--profile_output_dir={run_dir / 'profiling'}")
    return command
def load_json_if_exists(path: Path) -> dict[str, Any] | None:
    """Parse ``path`` as JSON, or return ``None`` when the path does not exist."""
    if path.exists():
        return json.loads(path.read_text())
    return None
def build_artifact_index(
    *,
    repo_id: str,
    run_dir: Path,
    policy_name: str,
    run_id: str,
) -> tuple[dict[str, Any], dict[str, Any], list[UploadTarget], str]:
    """Index the files of one profiling run for publication.

    The row goes to ``rows/<policy>/<run_id>.json``; every other file goes under
    ``artifacts/<policy>/<run_id>/``.

    Args:
        repo_id: Repo id used to build the file URLs.
        run_dir: Local directory of the run (holds ``stdout.txt``,
            ``stderr.txt`` and an optional ``profiling/`` tree).
        policy_name: Policy the run belongs to.
        run_id: Identifier of the run.

    Returns:
        ``(artifact_paths, artifact_urls, targets, row_path_in_repo)``. The two
        dicts share the same keys and hold repo paths and URLs respectively;
        ``targets`` lists the files to upload (the row itself is not included).
    """
    row_path_in_repo = f"rows/{policy_name}/{run_id}.json"
    artifact_root = f"artifacts/{policy_name}/{run_id}"

    artifact_paths: dict[str, Any] = {"row": row_path_in_repo}
    artifact_urls: dict[str, Any] = {"row": make_hub_file_url(repo_id, row_path_in_repo)}
    for group in ("profiling_files", "cprofile_summaries", "torch_tables", "trace_files"):
        artifact_paths[group] = {}
        artifact_urls[group] = {}
    targets: list[UploadTarget] = []

    def _index(repo_path: str, key: str, group: str | None = None) -> None:
        # Record a repo path and its URL under the same key, either at the top
        # level or inside one of the grouped sub-dicts.
        paths = artifact_paths if group is None else artifact_paths[group]
        urls = artifact_urls if group is None else artifact_urls[group]
        paths[key] = repo_path
        urls[key] = make_hub_file_url(repo_id, repo_path)

    for log_name in ("stdout.txt", "stderr.txt"):
        log_path = run_dir / log_name
        if not log_path.exists():
            continue
        repo_path = f"{artifact_root}/{log_name}"
        _index(repo_path, log_name.removesuffix(".txt"))
        targets.append(UploadTarget(local_path=log_path, path_in_repo=repo_path))

    profiling_dir = run_dir / "profiling"
    profiling_entries = sorted(profiling_dir.rglob("*")) if profiling_dir.exists() else []
    for path in profiling_entries:
        if not path.is_file():
            continue
        # Repo paths and URLs always use forward slashes, whatever the host OS.
        relative_path = path.relative_to(run_dir).as_posix()
        repo_path = f"{artifact_root}/{relative_path}"
        _index(repo_path, relative_path, "profiling_files")
        targets.append(UploadTarget(local_path=path, path_in_repo=repo_path))

        # Additionally file well-known artifacts under a dedicated key.
        if path.name == "step_timing_summary.json":
            _index(repo_path, "step_timing_summary")
        elif "cprofile" in path.parts:
            _index(repo_path, path.stem, "cprofile_summaries")
        elif "torch_tables" in path.parts:
            _index(repo_path, path.name, "torch_tables")
        elif "torch_traces" in path.parts:
            _index(repo_path, path.name, "trace_files")

    return artifact_paths, artifact_urls, targets, row_path_in_repo
def upload_profile_run(
    *,
    repo_id: str,
    row_path: Path,
    row_path_in_repo: str,
    artifact_targets: list[UploadTarget],
) -> dict[str, str]:
    """Publish one profiling run to a public dataset repo.

    The artifact files are uploaded first and the row file last.

    Returns:
        Mapping of each uploaded repo path to its URL, as returned by
        ``upload_targets``.
    """
    row_target = UploadTarget(local_path=row_path, path_in_repo=row_path_in_repo)
    all_targets = list(artifact_targets)
    all_targets.append(row_target)
    return upload_targets(
        repo_id=repo_id,
        targets=all_targets,
        repo_type="dataset",
        private=False,
        commit_message=f"Add model profiling row {row_path_in_repo}",
    )
def main() -> int:
# Entry point: run lerobot-train once per selected policy, save its output and
# a JSON "row" describing the run, and optionally publish both to the Hub.
args = parse_args()
specs = load_specs(args.spec_file)
selected = get_selected_names(args.policies, specs)
args.output_dir.mkdir(parents=True, exist_ok=True)
repo_id = normalize_repo_id(args.results_repo, args.hub_org)
# Prefer the commit given on the command line; otherwise ask git for HEAD.
git_commit = args.git_commit or subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip()
for policy_name in selected:
spec = specs[policy_name]
# Each run gets its own directory: <output_dir>/<policy>/<timestamp>__<policy>.
run_id = f"{utc_timestamp_slug()}__{policy_name}"
run_dir = args.output_dir / policy_name / run_id
run_dir.mkdir(parents=True, exist_ok=True)
cmd = build_train_command(spec, run_dir, args.profile_mode)
start = time.perf_counter()
# Output is captured rather than streamed so it can be saved to the files below.
result = subprocess.run(cmd, capture_output=True, text=True)
duration_s = time.perf_counter() - start
stdout_path = run_dir / "stdout.txt"
stderr_path = run_dir / "stderr.txt"
stdout_path.write_text(result.stdout)
stderr_path.write_text(result.stderr)
# Both summary files are optional; a missing file is recorded as an empty dict.
profile_summary = load_json_if_exists(run_dir / "profiling" / "step_timing_summary.json") or {}
deterministic_forward = (
load_json_if_exists(run_dir / "profiling" / "deterministic_forward.json") or {}
)
artifact_paths, artifact_urls, artifact_targets, row_path_in_repo = build_artifact_index(
repo_id=repo_id,
run_dir=run_dir,
policy_name=policy_name,
run_id=run_id,
)
row = {
"schema_version": 1,
"created_at": datetime.now(UTC).isoformat(),
"run_id": run_id,
"policy": policy_name,
"git_commit": git_commit,
# A non-zero exit code is recorded here; it does not raise.
"status": "success" if result.returncode == 0 else "failed",
"return_code": result.returncode,
"profile_mode": args.profile_mode,
"wall_time_s": duration_s,
"spec": {
"steps": spec.steps,
"train_args": spec.train_args,
},
"step_timing_summary": profile_summary,
"deterministic_forward": deterministic_forward,
"artifact_paths": artifact_paths,
"artifact_urls": artifact_urls,
# Only the last 20 stderr lines are embedded in the row.
"stderr_tail": result.stderr.splitlines()[-20:],
}
# The row is written locally before any upload is attempted.
row_path = run_dir / "profiling_row.json"
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
if args.publish:
uploaded_paths = upload_profile_run(
repo_id=repo_id,
row_path=row_path,
row_path_in_repo=row_path_in_repo,
artifact_targets=artifact_targets,
)
# Rewrite the local row so it also lists the uploaded file URLs.
row["uploaded_paths"] = uploaded_paths
row_path.write_text(json.dumps(row, indent=2, sort_keys=True))
print(json.dumps(row, indent=2, sort_keys=True))
# NOTE(review): the exit code is 0 whatever the per-policy status was, so a
# failed training run does not fail the caller — confirm this is intended.
return 0
if __name__ == "__main__":
raise SystemExit(main())
+20
View File
@@ -56,6 +56,16 @@ class TrainPipelineConfig(HubMixin):
# Number of workers for the dataloader. # Number of workers for the dataloader.
num_workers: int = 4 num_workers: int = 4
batch_size: int = 8 batch_size: int = 8
profile_mode: str = "off"
profile_wait_steps: int = 1
profile_warmup_steps: int = 2
profile_active_steps: int = 6
profile_repeat: int = 1
profile_output_dir: Path | None = None
profile_record_shapes: bool = True
profile_with_memory: bool = True
profile_with_flops: bool = True
profile_with_stack: bool = False
steps: int = 100_000 steps: int = 100_000
eval_freq: int = 20_000 eval_freq: int = 20_000
log_freq: int = 200 log_freq: int = 200
@@ -128,9 +138,19 @@ class TrainPipelineConfig(HubMixin):
now = dt.datetime.now() now = dt.datetime.now()
train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}" train_dir = f"{now:%Y-%m-%d}/{now:%H-%M-%S}_{self.job_name}"
self.output_dir = Path("outputs/train") / train_dir self.output_dir = Path("outputs/train") / train_dir
if self.profile_mode != "off" and self.profile_output_dir is None:
self.profile_output_dir = self.output_dir / "profiling"
if isinstance(self.dataset.repo_id, list): if isinstance(self.dataset.repo_id, list):
raise NotImplementedError("LeRobotMultiDataset is not currently implemented.") raise NotImplementedError("LeRobotMultiDataset is not currently implemented.")
if self.profile_mode not in {"off", "summary", "trace"}:
raise ValueError(
f"`profile_mode` must be one of 'off', 'summary', or 'trace', got {self.profile_mode}."
)
if self.profile_wait_steps < 0 or self.profile_warmup_steps < 0 or self.profile_active_steps < 0:
raise ValueError("Profiler schedule steps must be non-negative.")
if self.profile_repeat <= 0:
raise ValueError("`profile_repeat` must be strictly positive.")
if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None): if not self.use_policy_training_preset and (self.optimizer is None or self.scheduler is None):
raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.") raise ValueError("Optimizer and Scheduler must be set when the policy presets are not used.")
+225 -115
View File
@@ -22,6 +22,7 @@ import dataclasses
import logging import logging
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from pathlib import Path
from pprint import pformat from pprint import pformat
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -49,6 +50,14 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.profiling_utils import (
StepTimingCollector,
ensure_dir,
make_torch_profiler,
run_with_cprofile,
write_deterministic_forward_artifacts,
write_torch_profiler_outputs,
)
from lerobot.utils.random_utils import set_seed from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import ( from lerobot.utils.utils import (
cycle, cycle,
@@ -71,6 +80,7 @@ def update_policy(
lr_scheduler=None, lr_scheduler=None,
lock=None, lock=None,
rabc_weights_provider=None, rabc_weights_provider=None,
timing_collector: StepTimingCollector | None = None,
) -> tuple[MetricsTracker, dict]: ) -> tuple[MetricsTracker, dict]:
""" """
Performs a single training step to update the policy's weights. Performs a single training step to update the policy's weights.
@@ -104,6 +114,7 @@ def update_policy(
rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch) rabc_batch_weights, rabc_batch_stats = rabc_weights_provider.compute_batch_weights(batch)
# Let accelerator handle mixed precision # Let accelerator handle mixed precision
forward_start = time.perf_counter()
with accelerator.autocast(): with accelerator.autocast():
# Use per-sample loss when RA-BC is enabled for proper weighting # Use per-sample loss when RA-BC is enabled for proper weighting
if rabc_batch_weights is not None: if rabc_batch_weights is not None:
@@ -122,11 +133,15 @@ def update_policy(
loss, output_dict = policy.forward(batch) loss, output_dict = policy.forward(batch)
# TODO(rcadene): policy.unnormalize_outputs(out_dict) # TODO(rcadene): policy.unnormalize_outputs(out_dict)
forward_s = time.perf_counter() - forward_start
# Use accelerator's backward method # Use accelerator's backward method
backward_start = time.perf_counter()
accelerator.backward(loss) accelerator.backward(loss)
backward_s = time.perf_counter() - backward_start
# Clip gradients if specified # Clip gradients if specified
optimizer_start = time.perf_counter()
if grad_clip_norm > 0: if grad_clip_norm > 0:
grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm) grad_norm = accelerator.clip_grad_norm_(policy.parameters(), grad_clip_norm)
else: else:
@@ -147,11 +162,19 @@ def update_policy(
# Update internal buffers if policy has update method # Update internal buffers if policy has update method
if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"): if has_method(accelerator.unwrap_model(policy, keep_fp32_wrapper=True), "update"):
accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update() accelerator.unwrap_model(policy, keep_fp32_wrapper=True).update()
optimizer_s = time.perf_counter() - optimizer_start
train_metrics.loss = loss.item() train_metrics.loss = loss.item()
train_metrics.grad_norm = grad_norm.item() train_metrics.grad_norm = grad_norm.item()
train_metrics.lr = optimizer.param_groups[0]["lr"] train_metrics.lr = optimizer.param_groups[0]["lr"]
train_metrics.update_s = time.perf_counter() - start_time train_metrics.update_s = time.perf_counter() - start_time
if timing_collector is not None:
timing_collector.record(
forward_s=forward_s,
backward_s=backward_s,
optimizer_s=optimizer_s,
total_update_s=train_metrics.update_s,
)
return train_metrics, output_dict return train_metrics, output_dict
@@ -206,6 +229,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process: if is_main_process:
logging.info(pformat(cfg.to_dict())) logging.info(pformat(cfg.to_dict()))
profiling_enabled = cfg.profile_mode != "off"
profile_output_dir = None
cprofile_dir = None
if profiling_enabled and is_main_process and cfg.profile_output_dir is not None:
profile_output_dir = ensure_dir(Path(cfg.profile_output_dir))
cprofile_dir = ensure_dir(profile_output_dir / "cprofile")
logging.info("Profiling enabled. Artifacts will be written to %s", profile_output_dir)
# Initialize wandb only on main process # Initialize wandb only on main process
if cfg.wandb.enable and cfg.wandb.project and is_main_process: if cfg.wandb.enable and cfg.wandb.project and is_main_process:
wandb_logger = WandBLogger(cfg) wandb_logger = WandBLogger(cfg)
@@ -229,7 +260,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Dataset loading synchronization: main process downloads first to avoid race conditions # Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process: if is_main_process:
logging.info("Creating dataset") logging.info("Creating dataset")
dataset = make_dataset(cfg) if cprofile_dir is not None:
dataset = run_with_cprofile("dataset_setup", cprofile_dir, make_dataset, cfg)
else:
dataset = make_dataset(cfg)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
@@ -247,11 +281,21 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process: if is_main_process:
logging.info("Creating policy") logging.info("Creating policy")
policy = make_policy( if is_main_process and cprofile_dir is not None:
cfg=cfg.policy, policy = run_with_cprofile(
ds_meta=dataset.meta, "policy_setup",
rename_map=cfg.rename_map, cprofile_dir,
) make_policy,
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
else:
policy = make_policy(
cfg=cfg.policy,
ds_meta=dataset.meta,
rename_map=cfg.rename_map,
)
if cfg.peft is not None: if cfg.peft is not None:
logging.info("Using PEFT! Wrapping model.") logging.info("Using PEFT! Wrapping model.")
@@ -305,16 +349,47 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
}, },
} }
preprocessor, postprocessor = make_pre_post_processors( if is_main_process and cprofile_dir is not None:
policy_cfg=cfg.policy, preprocessor, postprocessor = run_with_cprofile(
pretrained_path=processor_pretrained_path, "processor_setup",
**processor_kwargs, cprofile_dir,
**postprocessor_kwargs, make_pre_post_processors,
) policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
else:
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
if is_main_process: if is_main_process:
logging.info("Creating optimizer and scheduler") logging.info("Creating optimizer and scheduler")
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy) if is_main_process and cprofile_dir is not None:
optimizer, lr_scheduler = run_with_cprofile(
"optimizer_setup",
cprofile_dir,
make_optimizer_and_scheduler,
cfg,
policy,
)
else:
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
if profiling_enabled and is_main_process and profile_output_dir is not None:
logging.info("Recording deterministic forward-pass artifacts")
write_deterministic_forward_artifacts(
policy=policy,
dataset=dataset,
batch_size=cfg.batch_size,
preprocessor=preprocessor,
output_dir=profile_output_dir,
device_type=device.type,
)
# Load precomputed SARM progress for RA-BC if enabled # Load precomputed SARM progress for RA-BC if enabled
# Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py # Generate progress using: src/lerobot/policies/sarm/compute_rabc_weights.py
@@ -429,124 +504,159 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
logging.info( logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}" f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
) )
timing_collector = StepTimingCollector() if profiling_enabled and is_main_process else None
profiler = None
profiler_context = nullcontext()
if profiling_enabled and is_main_process and profile_output_dir is not None:
if device.type == "cuda":
torch.cuda.reset_peak_memory_stats(device)
profiler = make_torch_profiler(cfg, profile_output_dir, device.type)
profiler_context = profiler
for _ in range(step, cfg.steps): with profiler_context:
start_time = time.perf_counter() for _ in range(step, cfg.steps):
batch = next(dl_iter) start_time = time.perf_counter()
batch = preprocessor(batch) batch = next(dl_iter)
train_tracker.dataloading_s = time.perf_counter() - start_time batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
train_tracker, output_dict = update_policy( train_tracker, output_dict = update_policy(
train_tracker, train_tracker,
policy, policy,
batch, batch,
optimizer, optimizer,
cfg.optimizer.grad_clip_norm, cfg.optimizer.grad_clip_norm,
accelerator=accelerator, accelerator=accelerator,
lr_scheduler=lr_scheduler, lr_scheduler=lr_scheduler,
rabc_weights_provider=rabc_weights, rabc_weights_provider=rabc_weights,
) timing_collector=timing_collector,
)
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here. # increment `step` here.
step += 1 step += 1
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger:
wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if cfg.save_checkpoint and is_saving_step:
if is_main_process: if is_main_process:
logging.info(f"Checkpoint policy after step {step}") progbar.update(1)
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) if timing_collector is not None:
save_checkpoint( timing_collector.record_dataloading(train_tracker.dataloading_s.val)
checkpoint_dir=checkpoint_dir, if device.type == "cuda":
step=step, timing_collector.record_memory(
cfg=cfg, step=step,
policy=accelerator.unwrap_model(policy), allocated_bytes=torch.cuda.memory_allocated(device),
optimizer=optimizer, reserved_bytes=torch.cuda.memory_reserved(device),
scheduler=lr_scheduler, )
preprocessor=preprocessor, train_tracker.step()
postprocessor=postprocessor, if profiler is not None:
) profiler.step()
update_last_checkpoint(checkpoint_dir) is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
if is_log_step:
logging.info(train_tracker)
if wandb_logger: if wandb_logger:
wandb_logger.log_policy(checkpoint_dir) wandb_log_dict = train_tracker.to_dict()
if output_dict:
wandb_log_dict.update(output_dict)
# Log RA-BC statistics if enabled
if rabc_weights is not None:
rabc_stats = rabc_weights.get_stats()
wandb_log_dict.update(
{
"rabc_delta_mean": rabc_stats["delta_mean"],
"rabc_delta_std": rabc_stats["delta_std"],
"rabc_num_frames": rabc_stats["num_frames"],
}
)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
accelerator.wait_for_everyone() if cfg.save_checkpoint and is_saving_step:
if is_main_process:
if cfg.env and is_eval_step: logging.info(f"Checkpoint policy after step {step}")
if is_main_process: checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
step_id = get_step_identifier(step, cfg.steps) save_checkpoint(
logging.info(f"Eval policy at step {step}") checkpoint_dir=checkpoint_dir,
with torch.no_grad(), accelerator.autocast(): step=step,
eval_info = eval_policy_all( cfg=cfg,
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy), policy=accelerator.unwrap_model(policy),
env_preprocessor=env_preprocessor, optimizer=optimizer,
env_postprocessor=env_postprocessor, scheduler=lr_scheduler,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
max_episodes_rendered=4,
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
) )
# overall metrics (suite-agnostic) update_last_checkpoint(checkpoint_dir)
aggregated = eval_info["overall"] if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
# optional: per-suite logging accelerator.wait_for_everyone()
for suite, suite_info in eval_info.items():
logging.info("Suite %s aggregated: %s", suite, suite_info)
# meters/tracker if cfg.env and is_eval_step:
eval_metrics = { if is_main_process:
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"), step_id = get_step_identifier(step, cfg.steps)
"pc_success": AverageMeter("success", ":.1f"), logging.info(f"Eval policy at step {step}")
"eval_s": AverageMeter("eval_s", ":.3f"), with torch.no_grad(), accelerator.autocast():
} eval_info = eval_policy_all(
eval_tracker = MetricsTracker( envs=eval_env, # dict[suite][task_id] -> vec_env
cfg.batch_size, policy=accelerator.unwrap_model(policy),
dataset.num_frames, env_preprocessor=env_preprocessor,
dataset.num_episodes, env_postprocessor=env_postprocessor,
eval_metrics, preprocessor=preprocessor,
initial_step=step, postprocessor=postprocessor,
accelerator=accelerator, n_episodes=cfg.eval.n_episodes,
) videos_dir=cfg.output_dir / "eval" / f"videos_step_{step_id}",
eval_tracker.eval_s = aggregated.pop("eval_s") max_episodes_rendered=4,
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward") start_seed=cfg.seed,
eval_tracker.pc_success = aggregated.pop("pc_success") max_parallel_tasks=cfg.env.max_parallel_tasks,
if wandb_logger: )
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info} # overall metrics (suite-agnostic)
wandb_logger.log_dict(wandb_log_dict, step, mode="eval") aggregated = eval_info["overall"]
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
accelerator.wait_for_everyone() # optional: per-suite logging
for suite, suite_info in eval_info.items():
logging.info("Suite %s aggregated: %s", suite, suite_info)
# meters/tracker
eval_metrics = {
"avg_sum_reward": AverageMeter("∑rwrd", ":.3f"),
"pc_success": AverageMeter("success", ":.1f"),
"eval_s": AverageMeter("eval_s", ":.3f"),
}
eval_tracker = MetricsTracker(
cfg.batch_size,
dataset.num_frames,
dataset.num_episodes,
eval_metrics,
initial_step=step,
accelerator=accelerator,
)
eval_tracker.eval_s = aggregated.pop("eval_s")
eval_tracker.avg_sum_reward = aggregated.pop("avg_sum_reward")
eval_tracker.pc_success = aggregated.pop("pc_success")
if wandb_logger:
wandb_log_dict = {**eval_tracker.to_dict(), **eval_info}
wandb_logger.log_dict(wandb_log_dict, step, mode="eval")
wandb_logger.log_video(eval_info["overall"]["video_paths"][0], step, mode="eval")
accelerator.wait_for_everyone()
if is_main_process: if is_main_process:
progbar.close() progbar.close()
if timing_collector is not None and profile_output_dir is not None:
extra_profile_metrics = {
"profile_mode": cfg.profile_mode,
"peak_memory_allocated_bytes": (
torch.cuda.max_memory_allocated(device) if device.type == "cuda" else None
),
"peak_memory_reserved_bytes": (
torch.cuda.max_memory_reserved(device) if device.type == "cuda" else None
),
}
timing_collector.write_json(
profile_output_dir / "step_timing_summary.json", extra=extra_profile_metrics
)
if profiler is not None and profile_output_dir is not None:
write_torch_profiler_outputs(profiler, profile_output_dir, device_type=device.type)
if eval_env: if eval_env:
close_envs(eval_env) close_envs(eval_env)
+297
View File
@@ -0,0 +1,297 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import cProfile
import hashlib
import io
import json
import pstats
import statistics
from collections.abc import Callable
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
import torch
from torch.utils.data._utils.collate import default_collate
def ensure_dir(path: Path) -> Path:
    """Create ``path`` as a directory, including missing parents, and return it.

    An already-existing directory is accepted silently, so the call can be repeated.
    Returning the same ``Path`` lets callers create and bind a directory in one expression.
    """
    path.mkdir(exist_ok=True, parents=True)
    return path
def render_cprofile_summary(
    profile: cProfile.Profile, *, sort_by: str = "cumulative", limit: int = 40
) -> str:
    """Render a collected cProfile run as the plain-text report produced by ``pstats``.

    Args:
        profile: Profile object holding the collected call statistics.
        sort_by: ``pstats`` sort key used to order the rows.
        limit: Restriction handed to ``print_stats``.

    Returns:
        The report text, with directory prefixes stripped from file names.
    """
    buffer = io.StringIO()
    report = pstats.Stats(profile, stream=buffer)
    report.strip_dirs()
    report.sort_stats(sort_by)
    report.print_stats(limit)
    return buffer.getvalue()
def write_profiler_table(
    profiler: Any,
    output_path: Path,
    *,
    sort_by: str,
    row_limit: int = 40,
) -> None:
    """Write the profiler's averaged-events table to ``output_path`` on a best-effort basis.

    Args:
        profiler: Object exposing ``key_averages().table(sort_by=..., row_limit=...)``.
        output_path: Text file that receives the rendered table.
        sort_by: Column name forwarded to ``table``.
        row_limit: Maximum number of rows forwarded to ``table``.

    If the table cannot be produced, nothing is written and no exception escapes.
    The failure is logged so a missing table file can be explained afterwards.
    """
    import logging

    try:
        table = profiler.key_averages().table(sort_by=sort_by, row_limit=row_limit)
    except Exception as exc:  # deliberate best-effort: a missing table must not abort the run
        logging.warning("Skipping profiler table %s (sort_by=%s): %s", output_path, sort_by, exc)
        return
    # Create the parent directory so callers that pass a fresh output location do not fail here.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(table)
def make_torch_profiler(cfg: Any, output_dir: Path, device_type: str) -> Any:
    """Build a ``torch.profiler.profile`` configured from the ``cfg.profile_*`` settings.

    CPU activity is always recorded; CUDA activity is added when ``device_type`` is ``"cuda"``.
    The ``torch_traces`` directory under ``output_dir`` is created immediately.

    Args:
        cfg: Config object providing the ``profile_*`` attributes read below.
        output_dir: Root directory for profiling artifacts.
        device_type: Device type string; only ``"cuda"`` changes the recorded activities.

    Returns:
        The configured (not yet started) profiler.
    """
    traces_root = ensure_dir(output_dir / "torch_traces")

    recorded_activities = [torch.profiler.ProfilerActivity.CPU]
    if device_type == "cuda":
        recorded_activities.append(torch.profiler.ProfilerActivity.CUDA)

    def _trace_ready(profiler: Any) -> None:
        # Chrome traces are exported only in "trace" mode; any other mode skips the export.
        if cfg.profile_mode == "trace":
            trace_file = traces_root / f"trace_step_{profiler.step_num}.json"
            profiler.export_chrome_trace(str(trace_file))

    step_schedule = torch.profiler.schedule(
        wait=cfg.profile_wait_steps,
        warmup=cfg.profile_warmup_steps,
        active=cfg.profile_active_steps,
        repeat=cfg.profile_repeat,
    )
    return torch.profiler.profile(
        activities=recorded_activities,
        schedule=step_schedule,
        on_trace_ready=_trace_ready,
        record_shapes=cfg.profile_record_shapes,
        profile_memory=cfg.profile_with_memory,
        with_flops=cfg.profile_with_flops,
        with_stack=cfg.profile_with_stack,
    )
def write_torch_profiler_outputs(
    profiler: Any,
    output_dir: Path,
    *,
    device_type: str,
) -> None:
    """Write the standard set of profiler tables into ``output_dir / "torch_tables"``.

    The CPU time, CPU memory and FLOPs tables are always requested.
    The two CUDA tables are requested only when ``device_type`` is ``"cuda"``.
    """
    tables_root = ensure_dir(output_dir / "torch_tables")

    # (file name, sort column) pairs, in the order the tables are written.
    requested = [("cpu_time_total.txt", "cpu_time_total")]
    if device_type == "cuda":
        requested.append(("cuda_time_total.txt", "self_cuda_time_total"))
        requested.append(("cuda_memory.txt", "self_cuda_memory_usage"))
    requested.append(("cpu_memory.txt", "self_cpu_memory_usage"))
    requested.append(("flops.txt", "flops"))

    for file_name, sort_column in requested:
        write_profiler_table(profiler, tables_root / file_name, sort_by=sort_column)
def run_with_cprofile[T](
label: str,
output_dir: Path,
fn: Callable[..., T],
*args: Any,
sort_by: str = "cumulative",
limit: int = 40,
**kwargs: Any,
) -> T:
ensure_dir(output_dir)
profile = cProfile.Profile()
profile.enable()
try:
return fn(*args, **kwargs)
finally:
profile.disable()
summary = render_cprofile_summary(profile, sort_by=sort_by, limit=limit)
(output_dir / f"{label}.txt").write_text(summary)
def _stable_float(value: float | int | None) -> float | None:
if value is None:
return None
return round(float(value), 8)
def _tensor_signature(tensor: torch.Tensor) -> dict[str, Any]:
    """Summarize a tensor as JSON-friendly metadata, rounded statistics and a content hash.

    Args:
        tensor: Tensor to describe; it is detached and copied to CPU first.

    Returns:
        A dict with ``shape``, ``dtype``, ``numel``, ``sha256`` and the statistics
        ``sum``, ``mean``, ``std``, ``min`` and ``max``.
        All statistics are ``None`` for an empty tensor.
    """
    cpu_tensor = tensor.detach().cpu()
    if cpu_tensor.numel() == 0:
        stats = {"sum": None, "mean": None, "std": None, "min": None, "max": None}
    else:
        # sum/min/max run on float64 for floating inputs and on int64 otherwise,
        # so integer-like tensors keep exact values.
        stats_tensor = (
            cpu_tensor.to(torch.float64) if cpu_tensor.is_floating_point() else cpu_tensor.to(torch.int64)
        )
        # mean/std need a floating dtype. float64 is used instead of float32 so these
        # statistics keep the same precision as the sum above and do not round large
        # integers (beyond 2**24) or float64 inputs before the 8-decimal rounding.
        wide_tensor = stats_tensor.to(torch.float64)
        stats = {
            "sum": _stable_float(stats_tensor.sum().item()),
            "mean": _stable_float(wide_tensor.mean().item()),
            "std": _stable_float(wide_tensor.std(unbiased=False).item())
            if cpu_tensor.numel() > 1
            else 0.0,
            "min": _stable_float(stats_tensor.min().item()),
            "max": _stable_float(stats_tensor.max().item()),
        }
    # bfloat16 is upcast before the NumPy conversion
    # (presumably because NumPy has no bfloat16 dtype -- confirm).
    hash_tensor = cpu_tensor.float() if cpu_tensor.dtype == torch.bfloat16 else cpu_tensor
    digest = hashlib.sha256(hash_tensor.contiguous().numpy().tobytes()).hexdigest()
    return {
        "shape": list(cpu_tensor.shape),
        "dtype": str(cpu_tensor.dtype),
        "numel": cpu_tensor.numel(),
        "sha256": digest,
        **stats,
    }
def _summarize_forward_value(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return _tensor_signature(value)
if isinstance(value, dict):
return {key: _summarize_forward_value(val) for key, val in value.items()}
if isinstance(value, (list, tuple)):
return [_summarize_forward_value(item) for item in value]
if isinstance(value, (str, int, float, bool)) or value is None:
return value
return repr(value)
def _hash_payload(payload: Any) -> str:
return hashlib.sha256(json.dumps(payload, sort_keys=True).encode()).hexdigest()
def _build_reference_batch(dataset: Any, batch_size: int) -> Any:
if len(dataset) == 0:
raise ValueError("Cannot build a reference batch from an empty dataset.")
indices = [idx % len(dataset) for idx in range(batch_size)]
samples = [dataset[idx] for idx in indices]
return default_collate(samples)
def write_deterministic_forward_artifacts(
    *,
    policy: Any,
    dataset: Any,
    batch_size: int,
    preprocessor: Any,
    output_dir: Path,
    device_type: str,
) -> None:
    """Run one seeded, profiled forward pass on a fixed batch and persist its fingerprints.

    The batch is built from the first ``batch_size`` dataset items and passed through ``preprocessor``.
    The forward pass runs in eval mode, without gradients, with the RNG forked and seeded to 0.
    The surrounding RNG state is therefore left untouched.

    Two files are written into ``output_dir`` (created if missing):
    ``deterministic_forward.json`` holds operator counts/timings and output summaries with their hashes.
    ``deterministic_forward_ops.txt`` holds the profiler table.

    Args:
        policy: Module exposing ``training``, ``eval()``, ``train()`` and a ``forward(batch)``
            that returns ``(loss, output_dict)``.
        dataset: Indexable dataset supporting ``len``.
        batch_size: Number of samples in the reference batch.
        preprocessor: Callable applied to the collated batch.
        output_dir: Directory receiving the artifacts.
        device_type: Device type string; ``"cuda"`` enables CUDA profiling and seeding.
    """
    reference_batch = preprocessor(_build_reference_batch(dataset, batch_size))
    activities = [torch.profiler.ProfilerActivity.CPU]
    if device_type == "cuda":
        activities.append(torch.profiler.ProfilerActivity.CUDA)

    was_training = policy.training
    policy.eval()
    try:
        with torch.random.fork_rng(devices=[] if device_type != "cuda" else None):
            torch.manual_seed(0)
            if device_type == "cuda":
                torch.cuda.manual_seed_all(0)
            with torch.no_grad(), torch.profiler.profile(activities=activities) as profiler:
                loss, output_dict = policy.forward(reference_batch)
    finally:
        # Restore the caller's mode even when the forward pass raises, so a failed
        # profiling attempt never leaves the policy stuck in eval mode.
        if was_training:
            policy.train()

    operator_entries = []
    for event in profiler.key_averages():
        entry = {
            "key": event.key,
            "count": event.count,
            "cpu_time_total_us": _stable_float(getattr(event, "cpu_time_total", None)),
        }
        if device_type == "cuda":
            entry["self_cuda_time_total_us"] = _stable_float(getattr(event, "self_cuda_time_total", None))
        operator_entries.append(entry)
    operator_entries = sorted(operator_entries, key=lambda item: item["key"])

    output_summary = {
        "loss": _summarize_forward_value(loss),
        "output_dict": _summarize_forward_value(output_dict),
    }
    payload = {
        "seed": 0,
        "reference_batch_size": batch_size,
        # Only operator names and call counts feed this fingerprint; timings vary run to run.
        "operator_fingerprint": _hash_payload([(entry["key"], entry["count"]) for entry in operator_entries]),
        "output_fingerprint": _hash_payload(output_summary),
        "operators": operator_entries,
        "outputs": output_summary,
    }
    # The destination may not exist yet; create it before the first write.
    ensure_dir(output_dir)
    (output_dir / "deterministic_forward.json").write_text(json.dumps(payload, indent=2, sort_keys=True))
    table_sort = "self_cuda_time_total" if device_type == "cuda" else "cpu_time_total"
    write_profiler_table(profiler, output_dir / "deterministic_forward_ops.txt", sort_by=table_sort)
def _summary(values: list[float]) -> dict[str, float] | dict[str, None]:
if not values:
return {"count": 0, "mean": None, "median": None, "min": None, "max": None}
return {
"count": len(values),
"mean": statistics.fmean(values),
"median": statistics.median(values),
"min": min(values),
"max": max(values),
}
@dataclass
class StepTimingCollector:
forward_s: list[float] = field(default_factory=list)
backward_s: list[float] = field(default_factory=list)
optimizer_s: list[float] = field(default_factory=list)
total_update_s: list[float] = field(default_factory=list)
dataloading_s: list[float] = field(default_factory=list)
memory_timeline: list[dict[str, float | int]] = field(default_factory=list)
def record(
self,
*,
forward_s: float,
backward_s: float,
optimizer_s: float,
total_update_s: float,
) -> None:
self.forward_s.append(forward_s)
self.backward_s.append(backward_s)
self.optimizer_s.append(optimizer_s)
self.total_update_s.append(total_update_s)
def record_dataloading(self, dataloading_s: float) -> None:
self.dataloading_s.append(dataloading_s)
def record_memory(self, *, step: int, allocated_bytes: int, reserved_bytes: int) -> None:
self.memory_timeline.append(
{
"step": step,
"allocated_bytes": allocated_bytes,
"reserved_bytes": reserved_bytes,
}
)
def to_dict(self) -> dict[str, Any]:
return {
"forward_s": _summary(self.forward_s),
"backward_s": _summary(self.backward_s),
"optimizer_s": _summary(self.optimizer_s),
"total_update_s": _summary(self.total_update_s),
"dataloading_s": _summary(self.dataloading_s),
"memory_timeline": self.memory_timeline,
}
def write_json(self, output_path: Path, extra: dict[str, Any] | None = None) -> None:
payload = self.to_dict()
if extra:
payload.update(extra)
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json.dumps(payload, indent=2, sort_keys=True))
+186
View File
@@ -0,0 +1,186 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import argparse
import importlib.util
import json
import subprocess
import sys
from pathlib import Path
def _import_model_profiling_script():
    """Load ``scripts/ci/run_model_profiling.py`` by file path and return it as a module.

    The script is located relative to this file, two directory levels up.
    Each call executes the script again and replaces the ``sys.modules`` entry.
    """
    module_name = "tests.scripts.run_model_profiling"
    base_dir = Path(__file__).resolve().parents[2]
    script_path = base_dir / "scripts" / "ci" / "run_model_profiling.py"

    spec = importlib.util.spec_from_file_location(module_name, script_path)
    module = importlib.util.module_from_spec(spec)
    # Registered before execution, presumably so code that resolves the module by
    # name while it is being executed can find it -- confirm.
    sys.modules[module_name] = module
    loader = spec.loader
    assert loader is not None
    loader.exec_module(module)
    return module
def test_profiling_specs_cover_expected_policies():
    """The checked-in spec file lists exactly the profiled policies and none of the excluded ones."""
    module = _import_model_profiling_script()
    base_dir = Path(__file__).resolve().parents[2]
    specs = module.load_specs(base_dir / "profiling" / "model_profiling_specs.json")

    expected_policies = {
        "act",
        "diffusion",
        "groot",
        "multi_task_dit",
        "pi0",
        "pi0_fast",
        "pi05",
        "smolvla",
        "wall_x",
        "xvla",
    }
    assert set(specs) == expected_policies

    # Explicit negative check for policies that must stay out of the profiling run.
    excluded_policies = ("sac", "sarm", "tdmpc", "vqbet", "reward_classifier")
    leaked = [name for name in excluded_policies if name in specs]
    assert leaked == []
def test_build_train_command_includes_profiling_outputs(tmp_path):
    """The generated training command carries the output locations and the fixed profiling flags."""
    module = _import_model_profiling_script()
    base_dir = Path(__file__).resolve().parents[2]
    act_spec = module.load_specs(base_dir / "profiling" / "model_profiling_specs.json")["act"]

    cmd = module.build_train_command(act_spec, tmp_path / "run", "trace")

    assert cmd[:3] == ["uv", "run", "lerobot-train"]
    # Directory arguments depend on tmp_path, so only their prefixes are checked.
    for prefix in ("--output_dir=", "--profile_output_dir="):
        assert any(arg.startswith(prefix) for arg in cmd)
    for flag in ("--profile_mode=trace", "--eval_freq=0", "--cudnn_deterministic=true"):
        assert flag in cmd
def test_build_artifact_index_collects_cprofile_tables_and_traces(tmp_path):
    """``build_artifact_index`` picks up every file of a populated run directory."""
    module = _import_model_profiling_script()
    run_id = "20260415T000000Z__act"
    run_dir = tmp_path / "act" / run_id
    profiling_dir = run_dir / "profiling"

    for subdir in ("cprofile", "torch_tables", "torch_traces"):
        (profiling_dir / subdir).mkdir(parents=True, exist_ok=True)

    # Seven files in total, matching the expected number of targets asserted below.
    fixture_files = {
        profiling_dir / "step_timing_summary.json": "{}",
        profiling_dir / "deterministic_forward.json": json.dumps(
            {"operator_fingerprint": "ops123", "output_fingerprint": "out123"}
        ),
        profiling_dir / "cprofile" / "policy_setup.txt": "policy setup",
        profiling_dir / "torch_tables" / "cpu_time_total.txt": "cpu table",
        profiling_dir / "torch_traces" / "trace_step_9.json": "{}",
        run_dir / "stdout.txt": "stdout",
        run_dir / "stderr.txt": "stderr",
    }
    for file_path, content in fixture_files.items():
        file_path.write_text(content)

    artifact_paths, artifact_urls, targets, row_path_in_repo = module.build_artifact_index(
        repo_id="lerobot/model-profiling-history",
        run_dir=run_dir,
        policy_name="act",
        run_id=run_id,
    )

    assert row_path_in_repo == "rows/act/20260415T000000Z__act.json"
    assert artifact_paths["stdout"].endswith("/stdout.txt")
    assert artifact_paths["step_timing_summary"].endswith("/profiling/step_timing_summary.json")
    assert "policy_setup" in artifact_paths["cprofile_summaries"]
    assert "cpu_time_total.txt" in artifact_paths["torch_tables"]
    assert "trace_step_9.json" in artifact_paths["trace_files"]
    forward_artifact = artifact_paths["profiling_files"]["profiling/deterministic_forward.json"]
    assert forward_artifact.endswith("/profiling/deterministic_forward.json")
    assert artifact_urls["row"].startswith("https://huggingface.co/datasets/lerobot/model-profiling-history/")
    assert len(targets) == 7
def test_model_profiling_main_smoke_writes_row(monkeypatch, tmp_path):
    """Run ``main()`` with the training subprocess stubbed out and inspect the row it writes."""
    module = _import_model_profiling_script()

    # Minimal single-policy spec file.
    act_spec = {
        "steps": 4,
        "train_args": [
            "--dataset.repo_id=lerobot/pusht",
            "--dataset.episodes=[0]",
            "--policy.type=act",
            "--policy.device=cuda",
            "--batch_size=4",
        ],
    }
    spec_file = tmp_path / "specs.json"
    spec_file.write_text(json.dumps({"act": act_spec}))

    args = argparse.Namespace(
        spec_file=spec_file,
        policies=["act"],
        output_dir=tmp_path / "results",
        hub_org="lerobot",
        results_repo="model-profiling-history",
        publish=False,
        profile_mode="summary",
        git_commit="",
    )

    def _constant_summary(value):
        return {"count": 1, "mean": value, "median": value, "min": value, "max": value}

    def _fake_run(cmd, capture_output, text):
        # Stand-in for the training subprocess: writes the artifacts a real run would leave behind.
        assert capture_output is True
        assert text is True
        profile_arg = next(arg for arg in cmd if arg.startswith("--profile_output_dir="))
        profile_dir = Path(profile_arg.split("=", 1)[1])
        for subdir in ("cprofile", "torch_tables"):
            (profile_dir / subdir).mkdir(parents=True, exist_ok=True)

        timing_payload = {
            "forward_s": _constant_summary(0.1),
            "total_update_s": _constant_summary(0.3),
            "peak_memory_allocated_bytes": 1024,
        }
        (profile_dir / "step_timing_summary.json").write_text(json.dumps(timing_payload))

        forward_payload = {
            "operator_fingerprint": "ops-fingerprint",
            "output_fingerprint": "output-fingerprint",
        }
        (profile_dir / "deterministic_forward.json").write_text(json.dumps(forward_payload))

        (profile_dir / "cprofile" / "policy_setup.txt").write_text("policy setup profile")
        (profile_dir / "torch_tables" / "cpu_time_total.txt").write_text("cpu time table")
        return subprocess.CompletedProcess(cmd, 0, "stdout ok", "")

    monkeypatch.setattr(module, "parse_args", lambda: args)
    # The trailing newline in the stubbed output does not appear in the row's git_commit asserted below.
    monkeypatch.setattr(module.subprocess, "check_output", lambda *a, **k: "deadbeef\n")
    monkeypatch.setattr(module.subprocess, "run", _fake_run)

    assert module.main() == 0

    row_paths = list((tmp_path / "results").rglob("profiling_row.json"))
    assert len(row_paths) == 1
    row = json.loads(row_paths[0].read_text())

    assert row["policy"] == "act"
    assert row["status"] == "success"
    assert row["git_commit"] == "deadbeef"
    assert row["step_timing_summary"]["forward_s"]["mean"] == 0.1
    assert row["deterministic_forward"]["operator_fingerprint"] == "ops-fingerprint"
    assert "policy_setup" in row["artifact_paths"]["cprofile_summaries"]