feat(train): run training remotely on HF Jobs via --job.target

When --job.target names a GPU flavor, train() dispatches to lerobot.jobs.submit_to_hf
instead of training locally: it authenticates, ensures the dataset is on the Hub
(pushing a local-only one privately), serializes a pod-compatible train_config.json
(strips client-only fields, points at the model repo), submits via HfApi.run_job
with HF_TOKEN/WANDB_API_KEY secrets, then streams logs and finishes when the model
is pushed. Wires push_checkpoint_to_hub into the training loop behind
save_checkpoint_to_hub, and tags jobs/datasets/model with 'lerobot' + --job.tags.
This commit is contained in:
Nicolas Rabault
2026-06-22 15:43:53 +02:00
parent 870d71aeb0
commit 607a8a6b68
7 changed files with 1002 additions and 0 deletions
+17
View File
@@ -0,0 +1,17 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .hf import submit_to_hf
__all__ = ["submit_to_hf"]
+56
View File
@@ -0,0 +1,56 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Make a training dataset reachable from an HF Job pod.
The pod can't see the host's ~/.cache/huggingface/lerobot, so the dataset has to
live on the Hub: the pod downloads it by repo_id at train time (the forwarded
HF_TOKEN covers private datasets). A dataset already on the Hub is used as-is; a
local-only dataset is pushed to a PRIVATE repo first (never public).
"""
from __future__ import annotations
import os
from pathlib import Path
from huggingface_hub.errors import RepositoryNotFoundError
def ensure_dataset_available(repo_id: str, *, api, tags: list[str] | None = None) -> None:
    """Ensure ``repo_id`` resolves on the Hub, pushing a local-only dataset privately first.

    Args:
        repo_id: Dataset repo id the job will download at train time.
        api: Object exposing ``dataset_info(repo_id)``; it is only used for the
            existence check.
        tags: Tags attached to the dataset only when we push it (an
            already-on-Hub dataset is left untouched).

    Raises:
        RuntimeError: If the dataset is neither on the Hub nor in the local cache.
    """
    try:
        api.dataset_info(repo_id)
    except RepositoryNotFoundError:
        pass
    else:
        # Already on the Hub: nothing to upload, the pod downloads it by repo_id.
        return

    configured_root = os.environ.get("HF_LEROBOT_HOME")
    if configured_root is None:
        # Derive the default from the Hugging Face home directory instead of a
        # hard-coded ~/.cache path, so a relocated cache (HF_HOME / XDG_CACHE_HOME)
        # is still found. Imported lazily, only on the local-lookup path.
        from huggingface_hub.constants import HF_HOME

        cache_root = Path(HF_HOME) / "lerobot"
    else:
        cache_root = Path(configured_root)
    cache_root = cache_root.expanduser()

    # meta/info.json is the marker used to decide that a dataset is present locally.
    local_present = (cache_root / repo_id / "meta" / "info.json").is_file()
    if not local_present:
        raise RuntimeError(
            f"Dataset '{repo_id}' is neither on the Hub nor in the local cache "
            f"({cache_root}). Record or download it first."
        )

    print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...")
    # Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed.
    from lerobot.datasets import LeRobotDataset

    LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags)
    print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.")
+332
View File
@@ -0,0 +1,332 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run a lerobot training on HF Jobs (HuggingFace GPUs).
Ported and simplified from lelab's runners/hf_cloud.py: no UI log queue, no
registry — just submit and stream to stdout.
"""
from __future__ import annotations
import copy
import datetime as dt
import io
import json
import netrc
import os
import re
import signal
import tempfile
import threading
from pathlib import Path
from typing import TYPE_CHECKING
import draccus
from huggingface_hub import get_token
if TYPE_CHECKING:
from lerobot.configs.train import TrainPipelineConfig
# Matches runs of characters outside [a-zA-Z0-9._-]; build_repo_id replaces each run with "-".
_SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+")
# Job stages at which _poll_until_done stops polling and returns the stage.
_TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"}
# Always attached to remote jobs and pushed datasets so LeRobot-originated work
# is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags.
LEROBOT_TAG = "lerobot"
def resolve_job_tags(extra: list[str] | None) -> list[str]:
    """Build the tag list for a run.

    The lerobot tag always comes first, followed by any extras; duplicates are
    dropped while the first-seen order is kept.
    """
    candidates = [LEROBOT_TAG]
    if extra:
        candidates.extend(extra)
    # dict keys keep insertion order, so this removes repeats without reordering.
    return list(dict.fromkeys(candidates))
def resolve_wandb_api_key() -> str | None:
    """Look up the host's wandb key so it can be forwarded to the job.

    Returns $WANDB_API_KEY when it is set and non-empty; otherwise the password
    stored for ``api.wandb.ai`` in ~/.netrc. Returns None when neither source
    yields a key or the netrc file cannot be read or parsed.
    """
    env_key = os.environ.get("WANDB_API_KEY")
    if env_key:
        return env_key

    try:
        netrc_file = netrc.netrc()
    except (FileNotFoundError, netrc.NetrcParseError, OSError):
        return None

    entry = netrc_file.authenticators("api.wandb.ai")
    if entry is None:
        return None
    _, _, secret = entry
    # An empty password counts as "no key".
    return secret if secret else None
def build_repo_id(username: str, job_name: str, now: dt.datetime) -> str:
    """Compose the model repo id for a remote run: <user>/<job_name>_<timestamp>.

    Runs of characters outside [a-zA-Z0-9._-] in ``job_name`` become a single
    dash and leading/trailing dashes are trimmed; a name with nothing left falls
    back to "train".
    """
    cleaned = _SLUG_RE.sub("-", job_name).strip("-")
    if not cleaned:
        cleaned = "train"
    timestamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    return f"{username}/{cleaned}_{timestamp}"
def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | None = None) -> Path:
    """Serialize ``cfg`` as a pod-ready train_config.json at ``dest``.

    The pod runs `lerobot-train --config_path=<dest>` and downloads the dataset
    by repo_id into its own cache. The caller's ``cfg`` is not modified: all
    overrides are applied to a deep copy.

    Remote overrides:
      * the policy is forced to push to ``repo_id``;
      * the policy device and the dataset root are cleared;
      * ``tags`` not already present in policy.tags are appended to it, so the
        trained model the pod pushes carries them too.

    Client-only keys are stripped from the serialized config: `job` is always
    removed, and `save_checkpoint_to_hub` is removed unless it is enabled —
    older lerobot images reject unknown keys, so the default keeps the config
    compatible with the released `lerobot-gpu` image.

    Returns:
        ``dest``, after writing the JSON file (parent directories are created).
    """
    pod_cfg = copy.deepcopy(cfg)
    policy = pod_cfg.policy
    policy.push_to_hub = True
    policy.repo_id = repo_id
    # Don't pin the client's resolved device (e.g. "mps"); let the pod auto-detect its GPU.
    policy.device = None
    # Drop any host-local dataset root; the pod resolves the dataset by repo_id.
    pod_cfg.dataset.root = None

    if tags:
        current = list(policy.tags or [])
        additions = []
        for tag in tags:
            # Membership is checked against the pre-existing tags only.
            if tag not in current:
                additions.append(tag)
        policy.tags = [*current, *additions]

    # Round-trip through draccus to get the canonical, pod-parseable layout, then
    # drop the keys the released trainer image doesn't know about.
    stream = io.StringIO()
    with draccus.config_type("json"):
        draccus.dump(pod_cfg, stream, indent=4)
    payload = json.loads(stream.getvalue())

    payload.pop("job", None)
    if not pod_cfg.save_checkpoint_to_hub:
        payload.pop("save_checkpoint_to_hub", None)

    dest.parent.mkdir(parents=True, exist_ok=True)
    dest.write_text(json.dumps(payload, indent=4))
    return dest
def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str:
    """Create the private model repo, upload train_config.json to it, and return ``repo_id``.

    The returned repo id is what the caller passes to the pod as --config_path.
    """
    from huggingface_hub import create_repo, upload_file

    create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token)

    with tempfile.TemporaryDirectory() as workdir:
        # The file only needs to live until the upload finishes.
        local_config = build_remote_config_file(
            cfg, repo_id, Path(workdir) / "train_config.json", tags=tags
        )
        upload_file(
            path_or_fileobj=local_config,
            path_in_repo="train_config.json",
            repo_id=repo_id,
            repo_type="model",
            token=token,
        )
    return repo_id
def _tail_logs(
    job_id: str,
    done: threading.Event,
    success_marker: str | None = None,
    success_event: threading.Event | None = None,
) -> None:
    """Print job logs to stdout, reconnecting after dropped streams until ``done`` is set.

    Each reconnect re-fetches the full buffered log, so the number of lines
    already printed is tracked and that many leading lines are skipped on later
    connections — otherwise a fast-failing job's traceback gets reprinted on
    every reconnect.

    When ``success_marker`` appears in a newly printed line (and ``success_event``
    is given), both ``success_event`` and ``done`` are set and the function
    returns, so the caller can finish as soon as the trained model lands on the
    Hub rather than waiting out the platform's post-run finalization.
    """
    from huggingface_hub import fetch_job_logs

    watch_for_marker = bool(success_marker) and success_event is not None
    shown = 0  # lines printed so far, across all connections
    while not done.is_set():
        try:
            for position, entry in enumerate(fetch_job_logs(job_id=job_id, follow=True), start=1):
                if position <= shown:
                    continue  # already shown on a previous connection
                shown = position
                # fetch_job_logs yields SSE data without trailing newlines, so add one
                # per entry — otherwise all log lines concatenate onto a single line.
                print(entry.rstrip("\n"), flush=True)
                if watch_for_marker and success_marker in entry:
                    success_event.set()
                    done.set()
                    return
                if done.is_set():
                    return
            # Stream closed cleanly. Wait a moment so the status poller can mark
            # the job terminal before we reconnect (avoids re-tailing the buffer).
            if done.wait(3):
                return
        except Exception:
            # Best-effort streaming: any failure just triggers a reconnect after a pause.
            if done.wait(2):
                return
def _poll_until_done(
    job_id: str,
    done: threading.Event,
    poll_interval: float = 5.0,
    status_holder: dict | None = None,
    max_failures: int = 6,
) -> str | None:
    """Poll inspect_job until the job is terminal, ``done`` is set, or polling keeps failing.

    Returns:
        The terminal stage string, or None when ``done`` was set first (detach)
        or after ``max_failures`` consecutive inspect_job errors.

    Side effects:
        Sets ``done`` when a terminal stage is reached or when giving up. On a
        terminal stage, if ``status_holder`` is given, stores the platform's
        status message (e.g. "Job timeout") under ``status_holder["message"]``.
    """
    from huggingface_hub import inspect_job

    consecutive_errors = 0
    while not done.is_set():
        try:
            info = inspect_job(job_id=job_id)
            consecutive_errors = 0  # a successful call resets the failure streak
            job_status = info.status
            stage = job_status.stage.value
            if stage in _TERMINAL_STAGES:
                if status_holder is not None:
                    status_holder["message"] = getattr(job_status, "message", None)
                done.set()
                return stage
        except Exception:
            consecutive_errors += 1
            if consecutive_errors >= max_failures:
                done.set()
                return None
        # Sleeps for poll_interval but wakes immediately if `done` gets set.
        done.wait(poll_interval)
    return None
def submit_to_hf(cfg: TrainPipelineConfig) -> None:
    """Submit a training job to HF Jobs and, unless detached, follow it to the end.

    Steps: resolve the Hub token and username, pick the model repo id, validate
    ``cfg``, collect secrets (``HF_TOKEN`` and, when wandb is enabled,
    ``WANDB_API_KEY``), make sure the dataset is on the Hub, stage
    train_config.json in the model repo, and submit the job. With
    ``cfg.job.detach`` the function returns right after submission; otherwise it
    streams logs and polls the job status until the model-pushed log line
    appears or the job reaches a terminal stage. Ctrl-C detaches without
    cancelling the remote job (the handler is only installed when running on the
    main thread, since Python forbids installing signal handlers elsewhere).

    Args:
        cfg: Training configuration. When a policy is present,
            ``cfg.policy.repo_id`` and ``cfg.policy.push_to_hub`` are set in place.

    Raises:
        RuntimeError: If no Hugging Face token is available, if the job reaches a
            terminal stage other than COMPLETED without the model-pushed marker,
            or if the job status could not be determined.
        ValueError: If wandb is enabled but no API key can be resolved.
    """
    from huggingface_hub import HfApi, run_job

    from lerobot.jobs.dataset import ensure_dataset_available

    token = get_token()
    if not token:
        raise RuntimeError("Not logged in to Hugging Face. Run `hf auth login` first.")
    api = HfApi(token=token)
    user_info = api.whoami(token=token)
    username = user_info["name"]
    now = dt.datetime.now()

    # The repo id must be set BEFORE cfg.validate(), otherwise validation rejects
    # a policy that pushes to the Hub without one.
    if cfg.policy is not None:
        base_name = cfg.job_name or cfg.policy.type
        repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
        cfg.policy.repo_id = repo_id
        cfg.policy.push_to_hub = True
    else:
        # Path-based policy is resolved inside validate(); fall back to a generic slug.
        repo_id = build_repo_id(username, cfg.job_name or "train", now)
    cfg.validate()

    secrets: dict[str, str] = {"HF_TOKEN": token}
    if cfg.wandb.enable:
        wandb_key = resolve_wandb_api_key()
        if wandb_key is None:
            raise ValueError(
                "wandb is enabled but no WANDB_API_KEY found. "
                "Set it via `export WANDB_API_KEY=...` or add it to ~/.netrc."
            )
        secrets["WANDB_API_KEY"] = wandb_key

    tags = resolve_job_tags(cfg.job.tags)
    ensure_dataset_available(cfg.dataset.repo_id, api=api, tags=tags)
    config_repo_id = _stage_config_on_hub(cfg, repo_id, token, tags=tags)
    command = ["lerobot-train", f"--config_path={config_repo_id}"]

    print(f"Submitting job to HF Jobs (flavor={cfg.job.target}, image={cfg.job.image}) ...")
    job_info = run_job(
        image=cfg.job.image,
        command=command,
        flavor=cfg.job.target,
        secrets=secrets,
        timeout=cfg.job.timeout,
        # HF Jobs labels are key/value; expose each tag as a queryable label.
        labels=dict.fromkeys(tags, "true"),
    )
    job_id = job_info.id
    job_url = getattr(job_info, "url", None)
    print(f"Job submitted: {job_id}")
    if job_url:
        print(f" Job page: {job_url}")
    print(f" Model repo: https://huggingface.co/{repo_id}")
    print(f" Monitor: hf jobs logs {job_id}")
    print(f" Cancel: hf jobs cancel {job_id}")

    if cfg.job.detach:
        return

    done = threading.Event()  # tells both worker threads to stop
    detached = threading.Event()  # set by Ctrl-C
    pushed_ok = threading.Event()  # set when the model-pushed log line is seen
    stage_holder: dict[str, str | None] = {}

    def _poll() -> None:
        stage_holder["stage"] = _poll_until_done(job_id, done, status_holder=stage_holder)

    poll_thread = threading.Thread(target=_poll, daemon=True)
    poll_thread.start()

    # Finish as soon as the model is pushed, rather than waiting out the platform's
    # post-run finalization before the job stage flips to COMPLETED.
    success_marker = f"Model pushed to https://huggingface.co/{repo_id}"
    log_thread = threading.Thread(
        target=_tail_logs, args=(job_id, done, success_marker, pushed_ok), daemon=True
    )
    log_thread.start()

    def _detach(sig, frame):
        detached.set()
        done.set()
        print("\nDetached. Job is still running.")
        print(f" Monitor: hf jobs logs {job_id}")
        print(f" Cancel: hf jobs cancel {job_id}")

    # signal.signal() raises ValueError outside the main thread, so only hook
    # Ctrl-C there; callers on a worker thread simply wait for the job to end.
    on_main_thread = threading.current_thread() is threading.main_thread()
    original_sigint = None
    if on_main_thread:
        original_sigint = signal.getsignal(signal.SIGINT)
        signal.signal(signal.SIGINT, _detach)
    try:
        # Timeout-based join so SIGINT is delivered to the main thread promptly.
        while poll_thread.is_alive():
            poll_thread.join(timeout=0.5)
        # Give the log tailer a bounded window to flush the remaining lines.
        log_thread.join(timeout=5)
    finally:
        if on_main_thread:
            signal.signal(signal.SIGINT, original_sigint)

    if detached.is_set():
        return
    if pushed_ok.is_set():
        print(f"\nTraining complete — model pushed to https://huggingface.co/{repo_id}")
        return

    stage = stage_holder.get("stage")
    if stage is None:
        # The poller gave up (repeated status-check failures) without seeing a
        # terminal stage: the job did not necessarily end, so don't claim it did.
        raise RuntimeError(
            f"Could not determine the status of job {job_id}; it may still be running. "
            f"Check logs: hf jobs logs {job_id}"
        )
    if stage != "COMPLETED":
        message = stage_holder.get("message")
        detail = f" ({message})" if message else ""
        raise RuntimeError(
            f"Job {job_id} ended with stage={stage}{detail}. Check logs: hf jobs logs {job_id}"
        )
+33
View File
@@ -41,6 +41,7 @@ from lerobot.common.train_utils import (
load_training_batch_size,
load_training_num_processes,
load_training_state,
push_checkpoint_to_hub,
save_checkpoint,
update_last_checkpoint,
)
@@ -188,6 +189,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
cfg: A `TrainPipelineConfig` object containing all training configurations.
accelerator: Optional Accelerator instance. If None, one will be created automatically.
"""
if cfg.job.is_remote:
from lerobot.jobs import submit_to_hf
return submit_to_hf(cfg)
from lerobot.utils.import_utils import require_package
require_package("accelerate", extra="training")
@@ -655,6 +661,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
optim_state_dict=optim_state_dict,
)
update_last_checkpoint(checkpoint_dir)
if cfg.save_checkpoint_to_hub:
push_checkpoint_to_hub(
checkpoint_dir,
cfg.policy.repo_id,
private=cfg.policy.private,
)
if wandb_logger:
wandb_logger.log_policy(checkpoint_dir)
@@ -735,8 +747,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
accelerator.end_training()
def _remote_target_in_argv() -> bool:
    """Tell whether the command line asks for a remote HF Jobs run.

    Scans sys.argv for `--job.target <value>` or `--job.target=<value>` (the last
    occurrence wins) and hands the value, or None when absent, to
    JobConfig.is_remote_target.
    """
    import sys

    from lerobot.configs.default import JobConfig

    cli_args = sys.argv[1:]
    requested = None
    for index, token in enumerate(cli_args):
        if token == "--job.target":
            # Space-separated form: the value is the next token, if there is one.
            if index + 1 < len(cli_args):
                requested = cli_args[index + 1]
        elif token.startswith("--job.target="):
            requested = token.partition("=")[2]
    return JobConfig.is_remote_target(requested)
def main():
    """CLI entry point: register plugins, quiet a client-side warning for remote runs, then train."""
    register_third_party_plugins()
    if _remote_target_in_argv():
        # The policy device is resolved on the remote pod, not here, so silence the
        # client-side "Device '...' is not available" warning PreTrainedConfig emits
        # while parsing the config (it fires before train() can dispatch remotely).
        policy_config_logger = logging.getLogger("lerobot.configs.policies")
        policy_config_logger.setLevel(logging.ERROR)
    train()
+78
View File
@@ -0,0 +1,78 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import MagicMock
import httpx
import pytest
from huggingface_hub.errors import RepositoryNotFoundError
from lerobot.jobs.dataset import ensure_dataset_available
def _repo_not_found() -> RepositoryNotFoundError:
    """Build a RepositoryNotFoundError carrying a synthetic 404 response."""
    request = httpx.Request("GET", "https://huggingface.co/datasets/test")
    not_found = httpx.Response(404, request=request)
    return RepositoryNotFoundError("nope", response=not_found)
def _api_with_dataset(exists: bool):
    """Return a mock API whose dataset_info succeeds (exists) or raises repo-not-found."""
    api = MagicMock()
    if not exists:
        api.dataset_info.side_effect = _repo_not_found()
        return api
    api.dataset_info.return_value = object()
    return api
def _make_local_cache(tmp_path, repo_id: str) -> None:
    """Lay out the minimal local cache (meta/info.json) that ensure_dataset_available looks for."""
    meta_dir = tmp_path / repo_id / "meta"
    meta_dir.mkdir(parents=True)
    (meta_dir / "info.json").write_text("{}")
# Branch 1: dataset already on Hub → no push, no error (pod downloads by repo_id).
def test_dataset_already_on_hub_is_noop():
    """A dataset found on the Hub needs no push: the call returns None after one lookup."""
    hub_api = _api_with_dataset(True)
    result = ensure_dataset_available("user/ds", api=hub_api)
    assert result is None
    hub_api.dataset_info.assert_called_once_with("user/ds")
# Branch 2: not on Hub but present locally → always push privately.
def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
    """A dataset missing from the Hub but cached locally is pushed privately with the tags."""
    monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
    _make_local_cache(tmp_path, "user/ds")
    hub_api = _api_with_dataset(False)

    # Stand in for lerobot.datasets so the lazy import inside the function gets our mock class.
    mock_ds_cls = MagicMock()
    fake_datasets_module = MagicMock(LeRobotDataset=mock_ds_cls)
    monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)

    result = ensure_dataset_available("user/ds", api=hub_api, tags=["lerobot", "lelab"])

    assert result is None
    mock_ds_cls.assert_called_once_with("user/ds")
    mock_ds_cls.return_value.push_to_hub.assert_called_once_with(private=True, tags=["lerobot", "lelab"])
# Branch 3: not on Hub, NOT in local cache → RuntimeError "neither".
def test_dataset_neither_on_hub_nor_local_raises(tmp_path, monkeypatch):
    """With no Hub repo and an empty local cache the call fails with the "neither" error."""
    # tmp_path is empty — no local cache.
    monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
    hub_api = _api_with_dataset(False)
    with pytest.raises(RuntimeError, match="neither"):
        ensure_dataset_available("user/ds", api=hub_api)
+426
View File
@@ -0,0 +1,426 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt
import json
import threading
from types import SimpleNamespace
import draccus
import pytest
from lerobot.configs.train import TrainPipelineConfig
from lerobot.jobs.hf import (
_poll_until_done,
build_remote_config_file,
build_repo_id,
resolve_job_tags,
resolve_wandb_api_key,
submit_to_hf,
)
def test_resolve_job_tags_always_includes_lerobot_and_dedups():
    """The lerobot tag always leads; extras follow once each, in first-seen order."""
    for no_extras in (None, []):
        assert resolve_job_tags(no_extras) == ["lerobot"]
    assert resolve_job_tags(["lelab"]) == ["lerobot", "lelab"]
    # lerobot isn't duplicated if passed explicitly; order is stable.
    repeated = ["lelab", "lerobot", "lelab"]
    assert resolve_job_tags(repeated) == ["lerobot", "lelab"]
def _fake_inspect(stage_value):
    """Return an inspect_job stand-in that always reports `stage_value` as the job stage."""

    def _inspect(job_id):
        stage = SimpleNamespace(value=stage_value)
        return SimpleNamespace(status=SimpleNamespace(stage=stage))

    return _inspect
def test_poll_until_done_returns_terminal_stage(monkeypatch):
    """A terminal stage is returned as-is and the shared `done` event gets set."""
    monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED"))
    stop = threading.Event()
    stage = _poll_until_done("j", stop, poll_interval=0.01)
    assert stage == "COMPLETED"
    assert stop.is_set()
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
    """With `done` pre-set the poller returns None without blocking, even if the job never ends."""
    # Non-terminal forever; only the pre-set event can end the loop.
    monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING"))
    stop = threading.Event()
    stop.set()
    stage = _poll_until_done("j", stop, poll_interval=0.01)
    assert stage is None
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
    """After `max_failures` consecutive inspect_job errors the poller returns None and sets `done`."""

    def _always_fails(job_id):
        raise RuntimeError("boom")

    monkeypatch.setattr("huggingface_hub.inspect_job", _always_fails)
    stop = threading.Event()
    outcome = _poll_until_done("j", stop, poll_interval=0.001, max_failures=3)
    assert outcome is None
    assert stop.is_set()
def test_resolve_wandb_key_from_env(monkeypatch):
    """$WANDB_API_KEY, when set, is returned directly."""
    monkeypatch.setenv("WANDB_API_KEY", "abc123")
    resolved = resolve_wandb_api_key()
    assert resolved == "abc123"
def test_resolve_wandb_key_missing(monkeypatch, tmp_path):
    """No env var and no readable ~/.netrc → None."""

    def _no_netrc(*args, **kwargs):
        raise FileNotFoundError()

    monkeypatch.delenv("WANDB_API_KEY", raising=False)
    monkeypatch.setenv("HOME", str(tmp_path))  # no ~/.netrc here
    monkeypatch.setattr("netrc.netrc", _no_netrc)
    assert resolve_wandb_api_key() is None
def test_resolve_wandb_key_from_netrc(monkeypatch):
    """Without the env var, the password stored for api.wandb.ai in ~/.netrc is used."""
    monkeypatch.delenv("WANDB_API_KEY", raising=False)

    class _FakeNetrc:
        def authenticators(self, host):
            assert host == "api.wandb.ai"
            return ("login", "account", "netrc-secret")

    def _open_fake_netrc(*args, **kwargs):
        return _FakeNetrc()

    monkeypatch.setattr("netrc.netrc", _open_fake_netrc)
    assert resolve_wandb_api_key() == "netrc-secret"
def test_resolve_wandb_key_netrc_without_wandb_entry(monkeypatch):
    """A ~/.netrc that has no api.wandb.ai entry yields None."""
    monkeypatch.delenv("WANDB_API_KEY", raising=False)

    class _FakeNetrc:
        def authenticators(self, host):
            return None

    def _open_fake_netrc(*args, **kwargs):
        return _FakeNetrc()

    monkeypatch.setattr("netrc.netrc", _open_fake_netrc)
    assert resolve_wandb_api_key() is None
def test_build_repo_id_sanitizes_and_timestamps():
    """Repo ids are <user>/<slug>_<timestamp>, with illegal characters collapsed to dashes."""
    moment = dt.datetime(2026, 6, 19, 10, 22, 3)
    stamp = "2026-06-19_10-22-03"
    assert build_repo_id("alice", "act", moment) == f"alice/act_{stamp}"
    # Runs of illegal characters collapse to a single dash; edges are trimmed.
    assert build_repo_id("alice", "my cool/run!!", moment) == f"alice/my-cool-run_{stamp}"
    # A name with nothing usable falls back to "train".
    assert build_repo_id("alice", "///", moment) == f"alice/train_{stamp}"
def _minimal_cfg():
    """Parse the smallest remote-training config: a dataset, an ACT policy and a GPU target."""
    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    return draccus.parse(TrainPipelineConfig, args=cli_args)
def test_build_remote_config_applies_overrides(tmp_path):
    """The pod config gets the remote overrides while the caller's cfg stays untouched."""
    cfg = _minimal_cfg()
    dest = tmp_path / "train_config.json"

    assert build_remote_config_file(cfg, "u/run", dest) == dest

    data = json.loads(dest.read_text())
    # `job` is client-only orchestration and must be stripped for the pod.
    assert "job" not in data
    # save_checkpoint_to_hub defaults off → omitted so older images accept the config.
    assert "save_checkpoint_to_hub" not in data

    policy = data["policy"]
    assert policy["push_to_hub"] is True
    assert policy["repo_id"] == "u/run"
    assert policy["device"] is None  # pod auto-detects its GPU
    assert data["dataset"]["root"] is None  # pod resolves the dataset by repo_id

    # the caller's cfg must be left untouched (function works on a deep copy)
    assert cfg.job.target == "a10g-small"
    assert cfg.save_checkpoint_to_hub is False
def test_build_remote_config_includes_checkpoint_flag_when_enabled(tmp_path):
    """An explicitly enabled save_checkpoint_to_hub survives in the pod config; `job` never does."""
    cli_args = [
        "--dataset.repo_id", "u/d",
        "--policy.type", "act",
        "--job.target", "a10g-small",
        "--save_checkpoint_to_hub", "true",
    ]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)
    dest = tmp_path / "train_config.json"

    build_remote_config_file(cfg, "u/run", dest)

    data = json.loads(dest.read_text())
    # explicitly enabled → kept in the config (requires a matching trainer image).
    assert data["save_checkpoint_to_hub"] is True
    assert "job" not in data
def test_build_remote_config_merges_tags_into_policy(tmp_path):
    """Tags passed to the builder end up in policy.tags of the pod config."""
    dest = tmp_path / "train_config.json"
    build_remote_config_file(_minimal_cfg(), "u/run", dest, tags=["lerobot", "lelab"])
    written = json.loads(dest.read_text())
    # tags propagate to the model the pod pushes.
    assert written["policy"]["tags"] == ["lerobot", "lelab"]
def test_build_remote_config_merges_tags_without_duplicating(tmp_path):
    """Existing policy tags are kept and only tags not already present are appended."""
    cfg = _minimal_cfg()
    cfg.policy.tags = ["existing", "lerobot"]
    dest = tmp_path / "train_config.json"

    build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])

    written = json.loads(dest.read_text())
    # pre-existing policy tags are kept; only genuinely-new tags are appended (no dup "lerobot").
    assert written["policy"]["tags"] == ["existing", "lerobot", "lelab"]
def test_submit_requires_login(monkeypatch):
    """Without a Hub token the submitter fails and points at `hf auth login`."""
    monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: None)
    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)
    with pytest.raises(RuntimeError, match="hf auth login"):
        submit_to_hf(cfg)
def test_submit_passes_validation_and_submits(monkeypatch):
    """Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
    from unittest.mock import MagicMock

    import huggingface_hub

    # Logged in as "alice".
    monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")

    class FakeHfApi:
        def __init__(self, token=None):
            pass

        def whoami(self, token=None):
            return {"name": "alice"}

    monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)

    # No Hub access: the dataset check and the config staging are both stubbed out.
    # (ensure_dataset_available is imported inside submit_to_hf via
    # `from lerobot.jobs.dataset import ensure_dataset_available`.)
    monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
    monkeypatch.setattr(
        "lerobot.jobs.hf._stage_config_on_hub",
        lambda cfg, repo_id, token, tags=None: repo_id,
    )

    # run_job records its keyword arguments and returns a fake job handle.
    fake_job = MagicMock()
    fake_job.id = "job-123"
    run_job_calls = []

    def fake_run_job(**kwargs):
        run_job_calls.append(kwargs)
        return fake_job

    monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job)

    cli_args = [
        "--dataset.repo_id", "u/d",
        "--policy.type", "act",
        "--job.target", "a10g-small",
        "--job.detach", "true",
    ]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    # Must NOT raise (pre-fix this raised ValueError about missing repo_id)
    submit_to_hf(cfg)

    assert len(run_job_calls) == 1, "run_job should have been called exactly once"
    assert cfg.policy.repo_id is not None
    assert cfg.policy.repo_id.startswith("alice/")

    (call,) = run_job_calls
    program, config_arg = call["command"][0], call["command"][1]
    # The pod runs `lerobot-train --config_path=<staged repo>` on the requested flavor/image.
    assert program == "lerobot-train"
    assert config_arg.startswith("--config_path=")
    assert call["flavor"] == "a10g-small"
    assert call["image"] == "huggingface/lerobot-gpu:latest"
    # The Hub token is forwarded so the pod can pull the (possibly private) dataset.
    assert call["secrets"]["HF_TOKEN"] == "tok"
    # Every job carries the lerobot tag as a queryable label.
    assert call["labels"].get("lerobot") == "true"
@pytest.mark.timeout(15)
def test_submit_returns_when_job_completes(monkeypatch):
    """Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
    from types import SimpleNamespace

    import huggingface_hub

    class FakeHfApi:
        def __init__(self, token=None):
            pass

        def whoami(self, token=None):
            return {"name": "alice"}

    def _completed_job(job_id):
        # Job is already COMPLETED on the first poll.
        stage = SimpleNamespace(value="COMPLETED")
        return SimpleNamespace(status=SimpleNamespace(stage=stage, message=None))

    monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
    monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
    monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
    monkeypatch.setattr(
        "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
    )
    monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
    monkeypatch.setattr("huggingface_hub.inspect_job", _completed_job)
    # Log stream ends immediately.
    monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))

    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    # Runs in the pytest main thread (signal handler install requires it); the
    # @timeout marker fails the test instead of hanging if it regresses.
    submit_to_hf(cfg)
@pytest.mark.timeout(15)
def test_submit_returns_on_model_pushed_marker(monkeypatch):
    """Finish when the model-pushed log appears, even if the job stage never flips."""
    from types import SimpleNamespace

    import huggingface_hub

    class FakeHfApi:
        def __init__(self, token=None):
            pass

        def whoami(self, token=None):
            return {"name": "alice"}

    def _running_job(job_id):
        # Job stays RUNNING forever — only the log marker can end the command.
        stage = SimpleNamespace(value="RUNNING")
        return SimpleNamespace(status=SimpleNamespace(stage=stage, message=None))

    monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
    monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
    monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
    monkeypatch.setattr(
        "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
    )
    monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
    monkeypatch.setattr("huggingface_hub.inspect_job", _running_job)

    pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
    monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))

    cli_args = [
        "--dataset.repo_id", "u/d",
        "--policy.type", "act",
        "--policy.repo_id", "alice/myrun",
        "--job.target", "a10g-small",
    ]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    # Must return via the model-pushed marker despite the perpetual RUNNING stage.
    submit_to_hf(cfg)
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
    """wandb.enable with no key reachable anywhere fails fast, before submitting."""
    import huggingface_hub

    class FakeHfApi:
        def __init__(self, token=None):
            pass

        def whoami(self, token=None):
            return {"name": "alice"}

    monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
    monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
    # Neither the environment nor ~/.netrc provides a key.
    monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)

    cli_args = [
        "--dataset.repo_id", "u/d",
        "--policy.type", "act",
        "--job.target", "a10g-small",
        "--wandb.enable", "true",
    ]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    with pytest.raises(ValueError, match="WANDB_API_KEY"):
        submit_to_hf(cfg)
@pytest.mark.timeout(15)
def test_submit_raises_when_job_ends_in_error(monkeypatch):
    """A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
    from types import SimpleNamespace

    import huggingface_hub

    class FakeHfApi:
        def __init__(self, token=None):
            pass

        def whoami(self, token=None):
            return {"name": "alice"}

    def _failed_job(job_id):
        # Job fails: a terminal ERROR stage carrying the platform's status message.
        stage = SimpleNamespace(value="ERROR")
        return SimpleNamespace(status=SimpleNamespace(stage=stage, message="Job timeout"))

    monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
    monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
    monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
    monkeypatch.setattr(
        "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
    )
    monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
    monkeypatch.setattr("huggingface_hub.inspect_job", _failed_job)
    # Logs end without the model-pushed marker.
    monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))

    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    with pytest.raises(RuntimeError, match=r"stage=ERROR \(Job timeout\)"):
        submit_to_hf(cfg)
@@ -0,0 +1,60 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import draccus
from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401 (registers --policy.type act)
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train
def _set_argv(monkeypatch, *args):
    """Replace sys.argv with a `lerobot-train` invocation carrying `args`."""
    fake_argv = ["lerobot-train"]
    fake_argv.extend(args)
    monkeypatch.setattr(sys, "argv", fake_argv)
def test_remote_target_detected_space_separated(monkeypatch):
    """`--job.target <flavor>` (two tokens) is recognized as a remote run."""
    _set_argv(monkeypatch, "--policy.type", "act", "--job.target", "a10g-small")
    detected = _remote_target_in_argv()
    assert detected is True
def test_remote_target_detected_equals(monkeypatch):
    """`--job.target=<flavor>` (single token) is recognized as a remote run."""
    _set_argv(monkeypatch, "--job.target=t4-small")
    detected = _remote_target_in_argv()
    assert detected is True
def test_local_string_is_not_remote(monkeypatch):
    """The literal target "local" does not count as a remote run."""
    _set_argv(monkeypatch, "--job.target", "local")
    detected = _remote_target_in_argv()
    assert detected is False
def test_no_target_is_not_remote(monkeypatch):
    """A command line without --job.target is not a remote run."""
    _set_argv(monkeypatch, "--policy.type", "act")
    detected = _remote_target_in_argv()
    assert detected is False
def test_train_dispatches_to_submit_when_remote(monkeypatch):
    """A remote --job.target short-circuits train() to the HF Jobs submitter."""
    import lerobot.jobs

    captured = []

    def _fake_submit(cfg):
        captured.append(cfg)
        return "submitted"

    monkeypatch.setattr(lerobot.jobs, "submit_to_hf", _fake_submit)

    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    # Returns the submitter's result and never enters the local training path.
    assert train(cfg) == "submitted"
    assert captured == [cfg]