feat(jobs): resume a run on HF Jobs from a checkpoint

When --resume is set with a remote --job.target, submit_to_hf resumes from the
checkpoint repo instead of staging a fresh config. A Hub config_path is resumed
in place (its checkpoint config already targets that repo); a local config_path
has its checkpoint uploaded to a new private repo first and the run is forced to
push back to it. The pod command carries --job.target=local so the checkpoint's
saved job.target can't make the pod re-dispatch itself, and the user's CLI
overrides are forwarded so a remote resume matches the same local command.
ensure_dataset_available is hoisted before the resume/fresh branch since it
applies to both.
This commit is contained in:
Nicolas Rabault
2026-06-24 10:16:03 +02:00
parent 838ab9e234
commit 651c113cd3
+89 -15
View File
@@ -26,6 +26,7 @@ import netrc
import os
import re
import signal
import sys
import tempfile
import threading
from pathlib import Path
@@ -42,6 +43,8 @@ from huggingface_hub import (
upload_file,
)
from lerobot.common.train_utils import push_checkpoint_to_hub
from lerobot.configs import parser
from lerobot.jobs.dataset import ensure_dataset_available
if TYPE_CHECKING:
@@ -218,12 +221,73 @@ def _poll_until_done(
return None
def _pod_forwarded_args(
argv: list[str], drop_names: tuple[str, ...] = (), drop_prefixes: tuple[str, ...] = ()
) -> list[str]:
"""User CLI overrides to replay on the pod, minus flags the submitter sets itself.
Handles both `--name=value` and `--name value` forms. Forwarding the user's overrides (e.g.
`--steps`, `--save_checkpoint_to_hub`) makes a remote resume behave like the same local command.
"""
out: list[str] = []
skip_next = False
for i, tok in enumerate(argv):
if skip_next:
skip_next = False
continue
name = tok.split("=", 1)[0]
if name in drop_names or any(name.startswith(p) for p in drop_prefixes):
if "=" not in tok and i + 1 < len(argv) and not argv[i + 1].startswith("--"):
skip_next = True # also drop the space-separated value
continue
out.append(tok)
return out
def _build_resume_job(cfg: TrainPipelineConfig, username: str) -> tuple[str, list[str]]:
"""Resolve the model repo and pod command to resume a run on a job.
A Hub `config_path` is resumed from directly: its checkpoint config already targets that repo,
so new checkpoints continue the lineage there. A local `config_path` has its checkpoint uploaded
to a new PRIVATE repo first, and the resumed run is forced to push back to it. The pod command
always carries `--job.target=local` so the checkpoint's saved `job.target` can't make the pod
re-dispatch itself.
"""
config_path = parser.parse_arg("config_path")
forwarded = _pod_forwarded_args(
sys.argv[1:],
drop_names=("--config_path", "--policy.repo_id", "--policy.push_to_hub"),
drop_prefixes=("--job.",),
)
if Path(config_path).exists():
# Local checkpoint: stage it on the Hub so the pod can resume from it, and push back there.
# Resolve so a `last` symlink uploads under its real step name (digit), which the pod's
# latest-checkpoint lookup keys on.
checkpoint_dir = Path(cfg.checkpoint_path).resolve()
source_repo = build_repo_id(username, cfg.job_name or "train", dt.datetime.now(dt.UTC))
push_checkpoint_to_hub(checkpoint_dir, source_repo, private=True)
extra = [f"--policy.repo_id={source_repo}", "--policy.push_to_hub=true"]
else:
source_repo = config_path
extra = []
command = [
"lerobot-train",
*forwarded,
f"--config_path={source_repo}",
"--job.target=local",
*extra,
]
return source_repo, command
def submit_to_hf(cfg: TrainPipelineConfig) -> None:
"""Submit a training job to HF Jobs infrastructure.
Validates cfg, resolves credentials, stages the config on the Hub, submits
the job, then either tails logs until completion or detaches immediately.
Ctrl-C detaches without cancelling the remote job.
Validates cfg, resolves credentials, ensures the dataset is on the Hub, then either stages a
sanitized config (fresh run) or resumes from a checkpoint repo, submits the job, and tails logs
until completion or detaches immediately. Ctrl-C detaches without cancelling the remote job.
"""
token = get_token()
if not token:
@@ -233,8 +297,20 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
user_info = api.whoami(token=token)
username = user_info["name"]
# validate() resolves a `--policy.path=...` policy into cfg.policy and skips its
# repo_id requirement for remote runs (we assign one below), so it's safe to run first.
now = dt.datetime.now(dt.UTC)
fresh_repo_id: str | None = None
if not cfg.resume:
# Resolve the model repo and mark it for push BEFORE validate(): validate() requires repo_id
# to be set whenever push_to_hub is True. (A resume reuses the checkpoint's repo instead.)
if cfg.policy is not None:
base_name = cfg.job_name or cfg.policy.type
fresh_repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
cfg.policy.repo_id = fresh_repo_id
cfg.policy.push_to_hub = True
else:
# Path-based policy is resolved inside validate(); fall back to a generic slug.
fresh_repo_id = build_repo_id(username, cfg.job_name or "train", now)
cfg.validate()
if cfg.is_reward_model_training:
@@ -243,14 +319,6 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
"Run reward-model training locally."
)
# Auto-generate the model repo unless the user pinned one. cfg.policy is guaranteed
# set here (validate() raises if neither policy nor reward_model is configured, and
# reward-model runs are rejected above).
now = dt.datetime.now(dt.UTC)
repo_id = cfg.policy.repo_id or build_repo_id(username, cfg.job_name or cfg.policy.type, now)
cfg.policy.repo_id = repo_id
cfg.policy.push_to_hub = True
secrets: dict[str, str] = {"HF_TOKEN": token}
if cfg.wandb.enable:
wandb_key = resolve_wandb_api_key()
@@ -262,10 +330,16 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
secrets["WANDB_API_KEY"] = wandb_key
tags = resolve_job_tags(cfg.job.tags)
# The dataset must be reachable from the pod for both fresh and resumed runs; a local-only
# dataset is pushed PRIVATE here. Hoisted before the resume/fresh branch since it applies to both.
ensure_dataset_available(cfg.dataset.repo_id, api=api, tags=tags)
config_repo_id = _stage_config_on_hub(cfg, repo_id, token, tags=tags)
command = ["lerobot-train", f"--config_path={config_repo_id}"]
if cfg.resume:
repo_id, command = _build_resume_job(cfg, username)
else:
config_repo_id = _stage_config_on_hub(cfg, fresh_repo_id, token, tags=tags)
repo_id = fresh_repo_id
command = ["lerobot-train", f"--config_path={config_repo_id}"]
print(f"Submitting job to HF Jobs (flavor={cfg.job.target}, image={cfg.job.image}) ...")
job_info = run_job(