mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-26 04:37:01 +00:00
feat(jobs): resume a run on HF Jobs from a checkpoint
When --resume is set with a remote --job.target, submit_to_hf resumes from the checkpoint repo instead of staging a fresh config. A Hub config_path is resumed in place (its checkpoint config already targets that repo); a local config_path has its checkpoint uploaded to a new private repo first and the run is forced to push back to it. The pod command carries --job.target=local so the checkpoint's saved job.target can't make the pod re-dispatch itself, and the user's CLI overrides are forwarded so a remote resume matches the same local command. ensure_dataset_available is hoisted before the resume/fresh branch since it applies to both.
This commit is contained in:
+89
-15
@@ -26,6 +26,7 @@ import netrc
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
@@ -42,6 +43,8 @@ from huggingface_hub import (
|
||||
upload_file,
|
||||
)
|
||||
|
||||
from lerobot.common.train_utils import push_checkpoint_to_hub
|
||||
from lerobot.configs import parser
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -218,12 +221,73 @@ def _poll_until_done(
|
||||
return None
|
||||
|
||||
|
||||
def _pod_forwarded_args(
|
||||
argv: list[str], drop_names: tuple[str, ...] = (), drop_prefixes: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
"""User CLI overrides to replay on the pod, minus flags the submitter sets itself.
|
||||
|
||||
Handles both `--name=value` and `--name value` forms. Forwarding the user's overrides (e.g.
|
||||
`--steps`, `--save_checkpoint_to_hub`) makes a remote resume behave like the same local command.
|
||||
"""
|
||||
out: list[str] = []
|
||||
skip_next = False
|
||||
for i, tok in enumerate(argv):
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
name = tok.split("=", 1)[0]
|
||||
if name in drop_names or any(name.startswith(p) for p in drop_prefixes):
|
||||
if "=" not in tok and i + 1 < len(argv) and not argv[i + 1].startswith("--"):
|
||||
skip_next = True # also drop the space-separated value
|
||||
continue
|
||||
out.append(tok)
|
||||
return out
|
||||
|
||||
|
||||
def _build_resume_job(cfg: TrainPipelineConfig, username: str) -> tuple[str, list[str]]:
    """Work out which model repo to resume from and the command the pod should run.

    Two cases, decided by whether the `config_path` CLI argument exists on the local filesystem:

    * It does not exist locally: it is used as-is as the repo to resume from, and no extra
      policy flags are added (the checkpoint config already targets that repo).
    * It exists locally: `cfg.checkpoint_path` is uploaded to a newly named PRIVATE repo, and the
      pod is told to push back to that repo via `--policy.repo_id` / `--policy.push_to_hub`.

    In both cases the command ends up with `--job.target=local` so the checkpoint's saved
    `job.target` can't make the pod re-dispatch itself, and the user's own CLI overrides are
    replayed minus `--config_path`, `--policy.repo_id`, `--policy.push_to_hub` and any `--job.*`.

    Args:
        cfg: Training config; `checkpoint_path` and `job_name` are read from it.
        username: Hub namespace under which a new repo id is built for a local checkpoint.

    Returns:
        `(repo_id, command)`: the repo the run resumes from and the pod's argv list.
    """
    config_path = parser.parse_arg("config_path")
    user_overrides = _pod_forwarded_args(
        sys.argv[1:],
        drop_names=("--config_path", "--policy.repo_id", "--policy.push_to_hub"),
        drop_prefixes=("--job.",),
    )

    if not Path(config_path).exists():
        # Not a local path: resume straight from the given repo, nothing to stage.
        resume_repo = config_path
        push_back_flags: list[str] = []
    else:
        # Local checkpoint: stage it on the Hub so the pod can reach it, then push back there.
        # Resolving first means a `last` symlink uploads under its real (digit) step name,
        # which the pod's latest-checkpoint lookup keys on.
        local_checkpoint = Path(cfg.checkpoint_path).resolve()
        created_at = dt.datetime.now(dt.UTC)
        resume_repo = build_repo_id(username, cfg.job_name or "train", created_at)
        push_checkpoint_to_hub(local_checkpoint, resume_repo, private=True)
        push_back_flags = [f"--policy.repo_id={resume_repo}", "--policy.push_to_hub=true"]

    pod_command = ["lerobot-train", *user_overrides]
    pod_command.append(f"--config_path={resume_repo}")
    pod_command.append("--job.target=local")
    pod_command.extend(push_back_flags)
    return resume_repo, pod_command
|
||||
|
||||
|
||||
def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
"""Submit a training job to HF Jobs infrastructure.
|
||||
|
||||
Validates cfg, resolves credentials, stages the config on the Hub, submits
|
||||
the job, then either tails logs until completion or detaches immediately.
|
||||
Ctrl-C detaches without cancelling the remote job.
|
||||
Validates cfg, resolves credentials, ensures the dataset is on the Hub, then either stages a
|
||||
sanitized config (fresh run) or resumes from a checkpoint repo, submits the job, and tails logs
|
||||
until completion or detaches immediately. Ctrl-C detaches without cancelling the remote job.
|
||||
"""
|
||||
token = get_token()
|
||||
if not token:
|
||||
@@ -233,8 +297,20 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
user_info = api.whoami(token=token)
|
||||
username = user_info["name"]
|
||||
|
||||
# validate() resolves a `--policy.path=...` policy into cfg.policy and skips its
|
||||
# repo_id requirement for remote runs (we assign one below), so it's safe to run first.
|
||||
now = dt.datetime.now(dt.UTC)
|
||||
fresh_repo_id: str | None = None
|
||||
if not cfg.resume:
|
||||
# Resolve the model repo and mark it for push BEFORE validate(): validate() requires repo_id
|
||||
# to be set whenever push_to_hub is True. (A resume reuses the checkpoint's repo instead.)
|
||||
if cfg.policy is not None:
|
||||
base_name = cfg.job_name or cfg.policy.type
|
||||
fresh_repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
|
||||
cfg.policy.repo_id = fresh_repo_id
|
||||
cfg.policy.push_to_hub = True
|
||||
else:
|
||||
# Path-based policy is resolved inside validate(); fall back to a generic slug.
|
||||
fresh_repo_id = build_repo_id(username, cfg.job_name or "train", now)
|
||||
|
||||
cfg.validate()
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
@@ -243,14 +319,6 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
"Run reward-model training locally."
|
||||
)
|
||||
|
||||
# Auto-generate the model repo unless the user pinned one. cfg.policy is guaranteed
|
||||
# set here (validate() raises if neither policy nor reward_model is configured, and
|
||||
# reward-model runs are rejected above).
|
||||
now = dt.datetime.now(dt.UTC)
|
||||
repo_id = cfg.policy.repo_id or build_repo_id(username, cfg.job_name or cfg.policy.type, now)
|
||||
cfg.policy.repo_id = repo_id
|
||||
cfg.policy.push_to_hub = True
|
||||
|
||||
secrets: dict[str, str] = {"HF_TOKEN": token}
|
||||
if cfg.wandb.enable:
|
||||
wandb_key = resolve_wandb_api_key()
|
||||
@@ -262,10 +330,16 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
secrets["WANDB_API_KEY"] = wandb_key
|
||||
|
||||
tags = resolve_job_tags(cfg.job.tags)
|
||||
# The dataset must be reachable from the pod for both fresh and resumed runs; a local-only
|
||||
# dataset is pushed PRIVATE here. Hoisted before the resume/fresh branch since it applies to both.
|
||||
ensure_dataset_available(cfg.dataset.repo_id, api=api, tags=tags)
|
||||
|
||||
config_repo_id = _stage_config_on_hub(cfg, repo_id, token, tags=tags)
|
||||
command = ["lerobot-train", f"--config_path={config_repo_id}"]
|
||||
if cfg.resume:
|
||||
repo_id, command = _build_resume_job(cfg, username)
|
||||
else:
|
||||
config_repo_id = _stage_config_on_hub(cfg, fresh_repo_id, token, tags=tags)
|
||||
repo_id = fresh_repo_id
|
||||
command = ["lerobot-train", f"--config_path={config_repo_id}"]
|
||||
|
||||
print(f"Submitting job to HF Jobs (flavor={cfg.job.target}, image={cfg.job.image}) ...")
|
||||
job_info = run_job(
|
||||
|
||||
Reference in New Issue
Block a user