From 05fddeb2bacf2baa93a01c09629f479491ed8849 Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Mon, 22 Jun 2026 15:43:53 +0200 Subject: [PATCH] feat(train): run training remotely on HF Jobs via --job.target When --job.target names a GPU flavor, train() dispatches to lerobot.jobs.submit_to_hf instead of training locally: it authenticates, ensures the dataset is on the Hub (pushing a local-only one privately), serializes a pod-compatible train_config.json (strips client-only fields, points at the model repo), submits via HfApi.run_job with HF_TOKEN/WANDB_API_KEY secrets, then streams logs and finishes when the model is pushed. Wires push_checkpoint_to_hub into the training loop behind save_checkpoint_to_hub, and tags jobs/datasets/model with 'lerobot' + --job.tags. --- src/lerobot/jobs/__init__.py | 17 + src/lerobot/jobs/dataset.py | 56 +++ src/lerobot/jobs/hf.py | 332 +++++++++++++++ src/lerobot/scripts/lerobot_train.py | 33 ++ tests/jobs/test_dataset.py | 78 ++++ tests/jobs/test_hf.py | 426 ++++++++++++++++++++ tests/scripts/test_train_remote_dispatch.py | 60 +++ 7 files changed, 1002 insertions(+) create mode 100644 src/lerobot/jobs/__init__.py create mode 100644 src/lerobot/jobs/dataset.py create mode 100644 src/lerobot/jobs/hf.py create mode 100644 tests/jobs/test_dataset.py create mode 100644 tests/jobs/test_hf.py create mode 100644 tests/scripts/test_train_remote_dispatch.py diff --git a/src/lerobot/jobs/__init__.py b/src/lerobot/jobs/__init__.py new file mode 100644 index 000000000..b13133752 --- /dev/null +++ b/src/lerobot/jobs/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .hf import submit_to_hf + +__all__ = ["submit_to_hf"] diff --git a/src/lerobot/jobs/dataset.py b/src/lerobot/jobs/dataset.py new file mode 100644 index 000000000..21c978f62 --- /dev/null +++ b/src/lerobot/jobs/dataset.py @@ -0,0 +1,56 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Make a training dataset reachable from an HF Job pod. + +The pod can't see the host's ~/.cache/huggingface/lerobot, so the dataset has to +live on the Hub: the pod downloads it by repo_id at train time (the forwarded +HF_TOKEN covers private datasets). A dataset already on the Hub is used as-is; a +local-only dataset is pushed to a PRIVATE repo first (never public). +""" + +from __future__ import annotations + +import os +from pathlib import Path + +from huggingface_hub.errors import RepositoryNotFoundError + + +def ensure_dataset_available(repo_id: str, *, api, tags: list[str] | None = None) -> None: + """Ensure repo_id resolves on the Hub, pushing a local-only dataset privately first. 
+ + `tags` are attached to the dataset only when we push it (an already-on-Hub + dataset is left untouched). Raises RuntimeError if the dataset is neither on + the Hub nor in the local cache. + """ + try: + api.dataset_info(repo_id) + return + except RepositoryNotFoundError: + pass + + cache_root = Path(os.environ.get("HF_LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser() + local_present = (cache_root / repo_id / "meta" / "info.json").is_file() + if not local_present: + raise RuntimeError( + f"Dataset '{repo_id}' is neither on the Hub nor in the local cache " + f"({cache_root}). Record or download it first." + ) + + print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...") + # Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed. + from lerobot.datasets import LeRobotDataset + + LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags) + print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.") diff --git a/src/lerobot/jobs/hf.py b/src/lerobot/jobs/hf.py new file mode 100644 index 000000000..035356a76 --- /dev/null +++ b/src/lerobot/jobs/hf.py @@ -0,0 +1,332 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Run a lerobot training on HF Jobs (HuggingFace GPUs). + +Ported and simplified from lelab's runners/hf_cloud.py: no UI log queue, no +registry — just submit and stream to stdout. 
+""" + +from __future__ import annotations + +import copy +import datetime as dt +import io +import json +import netrc +import os +import re +import signal +import tempfile +import threading +from pathlib import Path +from typing import TYPE_CHECKING + +import draccus +from huggingface_hub import get_token + +if TYPE_CHECKING: + from lerobot.configs.train import TrainPipelineConfig + +_SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+") + +_TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"} + +# Always attached to remote jobs and pushed datasets so LeRobot-originated work +# is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags. +LEROBOT_TAG = "lerobot" + + +def resolve_job_tags(extra: list[str] | None) -> list[str]: + """Return the tag list for a run: the lerobot tag plus any extras, deduped, order-stable.""" + tags = [LEROBOT_TAG, *(extra or [])] + seen: set[str] = set() + return [t for t in tags if not (t in seen or seen.add(t))] + + +def resolve_wandb_api_key() -> str | None: + """Host's wandb key for forwarding to the job: $WANDB_API_KEY, else ~/.netrc.""" + key = os.environ.get("WANDB_API_KEY") + if key: + return key + try: + rc = netrc.netrc() + except (FileNotFoundError, netrc.NetrcParseError, OSError): + return None + auth = rc.authenticators("api.wandb.ai") + if auth is None: + return None + _login, _account, password = auth + return password or None + + +def build_repo_id(username: str, job_name: str, now: dt.datetime) -> str: + """Generate the model repo id for a remote run: {username}/{slug}_{stamp}.""" + slug = _SLUG_RE.sub("-", job_name).strip("-") or "train" + stamp = now.strftime("%Y-%m-%d_%H-%M-%S") + return f"{username}/{slug}_{stamp}" + + +def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | None = None) -> Path: + """Write a train_config.json for the pod, with remote overrides applied. + + The pod runs `lerobot-train --config_path={repo_id}` and downloads the dataset + by repo_id into its own cache. 
Client-only fields are stripped so the config + is accepted by the trainer image: `job` (pure client orchestration) is always + removed, and `save_checkpoint_to_hub` is removed unless explicitly enabled — + older lerobot images reject unknown keys, so the default keeps the config + compatible with the released `lerobot-gpu` image. `tags` are merged into + policy.tags so the trained model the pod pushes carries them too. + """ + remote = copy.deepcopy(cfg) + remote.policy.push_to_hub = True + remote.policy.repo_id = repo_id + # Don't pin the client's resolved device (e.g. "mps"); let the pod auto-detect its GPU. + remote.policy.device = None + # Drop any host-local dataset root; the pod resolves the dataset by repo_id. + remote.dataset.root = None + if tags: + existing = list(remote.policy.tags or []) + remote.policy.tags = existing + [t for t in tags if t not in existing] + + # Round-trip through draccus to get the canonical, pod-parseable layout, then + # drop the keys the released trainer image doesn't know about. 
+ buf = io.StringIO() + with draccus.config_type("json"): + draccus.dump(remote, buf, indent=4) + data = json.loads(buf.getvalue()) + data.pop("job", None) + if not remote.save_checkpoint_to_hub: + data.pop("save_checkpoint_to_hub", None) + + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_text(json.dumps(data, indent=4)) + return dest + + +def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str: + """Upload train_config.json to the model repo and return the repo_id for --config_path.""" + from huggingface_hub import create_repo, upload_file + + create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token) + with tempfile.TemporaryDirectory() as tmp: + config_path = build_remote_config_file(cfg, repo_id, Path(tmp) / "train_config.json", tags=tags) + upload_file( + path_or_fileobj=config_path, + path_in_repo="train_config.json", + repo_id=repo_id, + repo_type="model", + token=token, + ) + return repo_id + + +def _tail_logs( + job_id: str, + done: threading.Event, + success_marker: str | None = None, + success_event: threading.Event | None = None, +) -> None: + """Stream job logs to stdout, reconnecting on dropped streams until done is set. + + Each reconnect re-fetches the full buffered log, so we track how many lines + were already printed and skip them — otherwise a fast-failing job's traceback + gets reprinted on every reconnect. + + When `success_marker` appears in a line, set `success_event` and `done` so the + caller can finish as soon as the trained model lands on the Hub, rather than + waiting out the platform's post-run finalization (which can add ~30s). 
+ """ + from huggingface_hub import fetch_job_logs + + printed = 0 + while not done.is_set(): + try: + seen = 0 + for line in fetch_job_logs(job_id=job_id, follow=True): + seen += 1 + if seen <= printed: + continue # already shown on a previous connection + printed = seen + # fetch_job_logs yields SSE data without trailing newlines, so add one + # per entry — otherwise all log lines concatenate onto a single line. + print(line.rstrip("\n"), flush=True) + if success_marker and success_event is not None and success_marker in line: + success_event.set() + done.set() + return + if done.is_set(): + return + # Stream closed cleanly. Wait a moment so the status poller can mark + # the job terminal before we reconnect (avoids re-tailing the buffer). + if done.wait(3): + return + except Exception: + if done.wait(2): + return + + +def _poll_until_done( + job_id: str, + done: threading.Event, + poll_interval: float = 5.0, + status_holder: dict | None = None, + max_failures: int = 6, +) -> str | None: + """Poll inspect_job until a terminal stage or until `done` is set. + + Returns the terminal stage string, or None if `done` was set first (detach) + or after `max_failures` consecutive inspect_job errors. When a terminal stage + is reached and `status_holder` is given, records `status_holder["message"]` + (the platform's status message, e.g. "Job timeout"). + """ + from huggingface_hub import inspect_job + + failures = 0 + while not done.is_set(): + try: + info = inspect_job(job_id=job_id) + failures = 0 + stage = info.status.stage.value + if stage in _TERMINAL_STAGES: + if status_holder is not None: + status_holder["message"] = getattr(info.status, "message", None) + done.set() + return stage + except Exception: + failures += 1 + if failures >= max_failures: + done.set() + return None + done.wait(poll_interval) + return None + + +def submit_to_hf(cfg: TrainPipelineConfig) -> None: + """Submit a training job to HF Jobs infrastructure. 
+ + Validates cfg, resolves credentials, stages the config on the Hub, submits + the job, then either tails logs until completion or detaches immediately. + Ctrl-C detaches without cancelling the remote job. + """ + from huggingface_hub import HfApi, run_job + + from lerobot.jobs.dataset import ensure_dataset_available + + token = get_token() + if not token: + raise RuntimeError("Not logged in to Hugging Face. Run `hf auth login` first.") + + api = HfApi(token=token) + user_info = api.whoami(token=token) + username = user_info["name"] + + now = dt.datetime.now() + if cfg.policy is not None: + base_name = cfg.job_name or cfg.policy.type + repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now) + cfg.policy.repo_id = repo_id + cfg.policy.push_to_hub = True + else: + # Path-based policy is resolved inside validate(); fall back to a generic slug. + repo_id = build_repo_id(username, cfg.job_name or "train", now) + + cfg.validate() + + secrets: dict[str, str] = {"HF_TOKEN": token} + if cfg.wandb.enable: + wandb_key = resolve_wandb_api_key() + if wandb_key is None: + raise ValueError( + "wandb is enabled but no WANDB_API_KEY found. " + "Set it via `export WANDB_API_KEY=...` or add it to ~/.netrc." + ) + secrets["WANDB_API_KEY"] = wandb_key + + tags = resolve_job_tags(cfg.job.tags) + ensure_dataset_available(cfg.dataset.repo_id, api=api, tags=tags) + + config_repo_id = _stage_config_on_hub(cfg, repo_id, token, tags=tags) + command = ["lerobot-train", f"--config_path={config_repo_id}"] + + print(f"Submitting job to HF Jobs (flavor={cfg.job.target}, image={cfg.job.image}) ...") + job_info = run_job( + image=cfg.job.image, + command=command, + flavor=cfg.job.target, + secrets=secrets, + timeout=cfg.job.timeout, + # HF Jobs labels are key/value; expose each tag as a queryable label. 
+ labels=dict.fromkeys(tags, "true"), + ) + job_id = job_info.id + job_url = getattr(job_info, "url", None) + print(f"Job submitted: {job_id}") + if job_url: + print(f" Job page: {job_url}") + print(f" Model repo: https://huggingface.co/{repo_id}") + print(f" Monitor: hf jobs logs {job_id}") + print(f" Cancel: hf jobs cancel {job_id}") + + if cfg.job.detach: + return + + done = threading.Event() + detached = threading.Event() + pushed_ok = threading.Event() + stage_holder: dict[str, str | None] = {} + + def _poll() -> None: + stage_holder["stage"] = _poll_until_done(job_id, done, status_holder=stage_holder) + + poll_thread = threading.Thread(target=_poll, daemon=True) + poll_thread.start() + # Finish as soon as the model is pushed, rather than waiting out the platform's + # post-run finalization before the job stage flips to COMPLETED. + success_marker = f"Model pushed to https://huggingface.co/{repo_id}" + log_thread = threading.Thread( + target=_tail_logs, args=(job_id, done, success_marker, pushed_ok), daemon=True + ) + log_thread.start() + + def _detach(sig, frame): + detached.set() + done.set() + print("\nDetached. Job is still running.") + print(f" Monitor: hf jobs logs {job_id}") + print(f" Cancel: hf jobs cancel {job_id}") + + original_sigint = signal.getsignal(signal.SIGINT) + signal.signal(signal.SIGINT, _detach) + try: + # Timeout-based join so SIGINT is delivered to the main thread promptly. + while poll_thread.is_alive(): + poll_thread.join(timeout=0.5) + log_thread.join(timeout=5) + finally: + signal.signal(signal.SIGINT, original_sigint) + + if detached.is_set(): + return + + if pushed_ok.is_set(): + print(f"\nTraining complete — model pushed to https://huggingface.co/{repo_id}") + return + + stage = stage_holder.get("stage") + if stage != "COMPLETED": + message = stage_holder.get("message") + detail = f" ({message})" if message else "" + raise RuntimeError( + f"Job {job_id} ended with stage={stage}{detail}. 
Check logs: hf jobs logs {job_id}" + ) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 45281dac9..621f0cdda 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -41,6 +41,7 @@ from lerobot.common.train_utils import ( load_training_batch_size, load_training_num_processes, load_training_state, + push_checkpoint_to_hub, save_checkpoint, update_last_checkpoint, ) @@ -187,6 +188,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): cfg: A `TrainPipelineConfig` object containing all training configurations. accelerator: Optional Accelerator instance. If None, one will be created automatically. """ + if cfg.job.is_remote: + from lerobot.jobs import submit_to_hf + + return submit_to_hf(cfg) + from lerobot.utils.import_utils import require_package require_package("accelerate", extra="training") @@ -597,6 +603,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): optim_state_dict=optim_state_dict, ) update_last_checkpoint(checkpoint_dir) + if cfg.save_checkpoint_to_hub: + push_checkpoint_to_hub( + checkpoint_dir, + cfg.policy.repo_id, + private=cfg.policy.private, + ) if wandb_logger: wandb_logger.log_policy(checkpoint_dir) @@ -677,8 +689,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): accelerator.end_training() +def _remote_target_in_argv() -> bool: + """True when the CLI requests a remote HF Jobs run (--job.target={flavor}).""" + import sys + + from lerobot.configs.default import JobConfig + + target = None + args = sys.argv[1:] + for i, tok in enumerate(args): + if tok == "--job.target" and i + 1 < len(args): + target = args[i + 1] + elif tok.startswith("--job.target="): + target = tok.split("=", 1)[1] + return JobConfig.is_remote_target(target) + + def main(): register_third_party_plugins() + if _remote_target_in_argv(): + # The policy device is resolved on the remote pod, not here, so silence 
the + # client-side "Device '...' is not available" warning PreTrainedConfig emits + # while parsing the config (it fires before train() can dispatch remotely). + logging.getLogger("lerobot.configs.policies").setLevel(logging.ERROR) train() diff --git a/tests/jobs/test_dataset.py b/tests/jobs/test_dataset.py new file mode 100644 index 000000000..56cf640b2 --- /dev/null +++ b/tests/jobs/test_dataset.py @@ -0,0 +1,78 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +from unittest.mock import MagicMock + +import httpx +import pytest +from huggingface_hub.errors import RepositoryNotFoundError + +from lerobot.jobs.dataset import ensure_dataset_available + + +def _repo_not_found() -> RepositoryNotFoundError: + req = httpx.Request("GET", "https://huggingface.co/datasets/test") + resp = httpx.Response(404, request=req) + return RepositoryNotFoundError("nope", response=resp) + + +def _api_with_dataset(exists: bool): + api = MagicMock() + if exists: + api.dataset_info.return_value = object() + else: + api.dataset_info.side_effect = _repo_not_found() + return api + + +def _make_local_cache(tmp_path, repo_id: str) -> None: + """Create the minimal local-cache layout that ensure_dataset_available checks.""" + info = tmp_path / repo_id / "meta" / "info.json" + info.parent.mkdir(parents=True) + info.write_text("{}") + + +# Branch 1: dataset already on Hub → no push, no error (pod downloads by repo_id). 
+def test_dataset_already_on_hub_is_noop(): + api = _api_with_dataset(True) + assert ensure_dataset_available("user/ds", api=api) is None + api.dataset_info.assert_called_once_with("user/ds") + + +# Branch 2: not on Hub but present locally → always push privately. +def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch): + monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path)) + _make_local_cache(tmp_path, "user/ds") + + api = _api_with_dataset(False) + mock_ds_cls = MagicMock() + fake_datasets_module = MagicMock() + fake_datasets_module.LeRobotDataset = mock_ds_cls + monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module) + + assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None + + mock_ds_cls.assert_called_once_with("user/ds") + mock_ds_cls.return_value.push_to_hub.assert_called_once_with(private=True, tags=["lerobot", "lelab"]) + + +# Branch 3: not on Hub, NOT in local cache → RuntimeError "neither". +def test_dataset_neither_on_hub_nor_local_raises(tmp_path, monkeypatch): + monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path)) + # tmp_path is empty — no local cache. + + api = _api_with_dataset(False) + with pytest.raises(RuntimeError, match="neither"): + ensure_dataset_available("user/ds", api=api) diff --git a/tests/jobs/test_hf.py b/tests/jobs/test_hf.py new file mode 100644 index 000000000..818e99c88 --- /dev/null +++ b/tests/jobs/test_hf.py @@ -0,0 +1,426 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime as dt +import json +import threading +from types import SimpleNamespace + +import draccus +import pytest + +from lerobot.configs.train import TrainPipelineConfig +from lerobot.jobs.hf import ( + _poll_until_done, + build_remote_config_file, + build_repo_id, + resolve_job_tags, + resolve_wandb_api_key, + submit_to_hf, +) + + +def test_resolve_job_tags_always_includes_lerobot_and_dedups(): + assert resolve_job_tags(None) == ["lerobot"] + assert resolve_job_tags([]) == ["lerobot"] + assert resolve_job_tags(["lelab"]) == ["lerobot", "lelab"] + # lerobot isn't duplicated if passed explicitly; order is stable. + assert resolve_job_tags(["lelab", "lerobot", "lelab"]) == ["lerobot", "lelab"] + + +def _fake_inspect(stage_value): + return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=SimpleNamespace(value=stage_value))) + + +def test_poll_until_done_returns_terminal_stage(monkeypatch): + monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED")) + done = threading.Event() + assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED" + assert done.is_set() + + +def test_poll_until_done_exits_when_done_already_set(monkeypatch): + # Non-terminal forever; with done pre-set the loop must not block and returns None. 
+ monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING")) + done = threading.Event() + done.set() + assert _poll_until_done("j", done, poll_interval=0.01) is None + + +def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch): + monkeypatch.setattr( + "huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom")) + ) + done = threading.Event() + result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3) + assert result is None + assert done.is_set() + + +def test_resolve_wandb_key_from_env(monkeypatch): + monkeypatch.setenv("WANDB_API_KEY", "abc123") + assert resolve_wandb_api_key() == "abc123" + + +def test_resolve_wandb_key_missing(monkeypatch, tmp_path): + monkeypatch.delenv("WANDB_API_KEY", raising=False) + monkeypatch.setenv("HOME", str(tmp_path)) # no ~/.netrc here + monkeypatch.setattr("netrc.netrc", lambda *a, **k: (_ for _ in ()).throw(FileNotFoundError())) + assert resolve_wandb_api_key() is None + + +def test_resolve_wandb_key_from_netrc(monkeypatch): + # No env var → fall back to the wandb credentials in ~/.netrc. + monkeypatch.delenv("WANDB_API_KEY", raising=False) + + class _FakeNetrc: + def authenticators(self, host): + assert host == "api.wandb.ai" + return ("login", "account", "netrc-secret") + + monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc()) + assert resolve_wandb_api_key() == "netrc-secret" + + +def test_resolve_wandb_key_netrc_without_wandb_entry(monkeypatch): + # ~/.netrc exists but has no api.wandb.ai entry → None. 
+ monkeypatch.delenv("WANDB_API_KEY", raising=False) + + class _FakeNetrc: + def authenticators(self, host): + return None + + monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc()) + assert resolve_wandb_api_key() is None + + +def test_build_repo_id_sanitizes_and_timestamps(): + now = dt.datetime(2026, 6, 19, 10, 22, 3) + assert build_repo_id("alice", "act", now) == "alice/act_2026-06-19_10-22-03" + # Runs of illegal characters collapse to a single dash; edges are trimmed. + assert build_repo_id("alice", "my cool/run!!", now) == "alice/my-cool-run_2026-06-19_10-22-03" + # A name with nothing usable falls back to "train". + assert build_repo_id("alice", "///", now) == "alice/train_2026-06-19_10-22-03" + + +def _minimal_cfg(): + return draccus.parse( + TrainPipelineConfig, + args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"], + ) + + +def test_build_remote_config_applies_overrides(tmp_path): + cfg = _minimal_cfg() + dest = tmp_path / "train_config.json" + out = build_remote_config_file(cfg, "u/run", dest) + assert out == dest + data = json.loads(dest.read_text()) + # `job` is client-only orchestration and must be stripped for the pod. + assert "job" not in data + # save_checkpoint_to_hub defaults off → omitted so older images accept the config. 
+ assert "save_checkpoint_to_hub" not in data + assert data["policy"]["push_to_hub"] is True + assert data["policy"]["repo_id"] == "u/run" + assert data["policy"]["device"] is None # pod auto-detects its GPU + assert data["dataset"]["root"] is None # pod resolves the dataset by repo_id + # the caller's cfg must be left untouched (function works on a deep copy) + assert cfg.job.target == "a10g-small" + assert cfg.save_checkpoint_to_hub is False + + +def test_build_remote_config_includes_checkpoint_flag_when_enabled(tmp_path): + cfg = draccus.parse( + TrainPipelineConfig, + args=[ + "--dataset.repo_id", + "u/d", + "--policy.type", + "act", + "--job.target", + "a10g-small", + "--save_checkpoint_to_hub", + "true", + ], + ) + dest = tmp_path / "train_config.json" + build_remote_config_file(cfg, "u/run", dest) + data = json.loads(dest.read_text()) + # explicitly enabled → kept in the config (requires a matching trainer image). + assert data["save_checkpoint_to_hub"] is True + assert "job" not in data + + +def test_build_remote_config_merges_tags_into_policy(tmp_path): + cfg = _minimal_cfg() + dest = tmp_path / "train_config.json" + build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"]) + data = json.loads(dest.read_text()) + # tags propagate to the model the pod pushes. + assert data["policy"]["tags"] == ["lerobot", "lelab"] + + +def test_build_remote_config_merges_tags_without_duplicating(tmp_path): + cfg = _minimal_cfg() + cfg.policy.tags = ["existing", "lerobot"] + dest = tmp_path / "train_config.json" + build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"]) + data = json.loads(dest.read_text()) + # pre-existing policy tags are kept; only genuinely-new tags are appended (no dup "lerobot"). 
+ assert data["policy"]["tags"] == ["existing", "lerobot", "lelab"] + + +def test_submit_requires_login(monkeypatch): + monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: None) + cfg = draccus.parse( + TrainPipelineConfig, + args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"], + ) + with pytest.raises(RuntimeError, match="hf auth login"): + submit_to_hf(cfg) + + +def test_submit_passes_validation_and_submits(monkeypatch): + """Regression: repo_id must be set BEFORE cfg.validate() or validation raises.""" + from unittest.mock import MagicMock + + import huggingface_hub + + # Patch get_token + monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") + + # Patch HfApi so whoami returns alice + class FakeHfApi: + def __init__(self, token=None): + pass + + def whoami(self, token=None): + return {"name": "alice"} + + monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + + # ensure_dataset_available returns None; patch it out so no Hub access happens + # (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`). 
+ monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None) + + # Patch _stage_config_on_hub to skip network + monkeypatch.setattr( + "lerobot.jobs.hf._stage_config_on_hub", + lambda cfg, repo_id, token, tags=None: repo_id, + ) + + # Patch run_job to return a fake job + fake_job = MagicMock() + fake_job.id = "job-123" + run_job_calls = [] + + def fake_run_job(**kwargs): + run_job_calls.append(kwargs) + return fake_job + + monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job) + + cfg = draccus.parse( + TrainPipelineConfig, + args=[ + "--dataset.repo_id", + "u/d", + "--policy.type", + "act", + "--job.target", + "a10g-small", + "--job.detach", + "true", + ], + ) + + # Must NOT raise (pre-fix this raised ValueError about missing repo_id) + submit_to_hf(cfg) + + assert len(run_job_calls) == 1, "run_job should have been called exactly once" + assert cfg.policy.repo_id is not None + assert cfg.policy.repo_id.startswith("alice/") + call = run_job_calls[0] + # The pod runs `lerobot-train --config_path={repo_id}` on the requested flavor/image. + assert call["command"][0] == "lerobot-train" + assert call["command"][1].startswith("--config_path=") + assert call["flavor"] == "a10g-small" + assert call["image"] == "huggingface/lerobot-gpu:latest" + # The Hub token is forwarded so the pod can pull the (possibly private) dataset. + assert call["secrets"]["HF_TOKEN"] == "tok" + # Every job carries the lerobot tag as a queryable label. 
+ assert call["labels"].get("lerobot") == "true" + + +@pytest.mark.timeout(15) +def test_submit_returns_when_job_completes(monkeypatch): + """Non-detach path must RETURN (not hang) once the job reaches a terminal stage.""" + from types import SimpleNamespace + + import huggingface_hub + + monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") + + class FakeHfApi: + def __init__(self, token=None): + pass + + def whoami(self, token=None): + return {"name": "alice"} + + monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None) + monkeypatch.setattr( + "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id + ) + monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) + # Job is already COMPLETED on the first poll. + monkeypatch.setattr( + "huggingface_hub.inspect_job", + lambda job_id: SimpleNamespace( + status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None) + ), + ) + # Log stream ends immediately. + monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(())) + + cfg = draccus.parse( + TrainPipelineConfig, + args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"], + ) + # Runs in the pytest main thread (signal handler install requires it); the + # @timeout marker fails the test instead of hanging if it regresses. 
+ submit_to_hf(cfg) + + +@pytest.mark.timeout(15) +def test_submit_returns_on_model_pushed_marker(monkeypatch): + """Finish when the model-pushed log appears, even if the job stage never flips.""" + from types import SimpleNamespace + + import huggingface_hub + + monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") + + class FakeHfApi: + def __init__(self, token=None): + pass + + def whoami(self, token=None): + return {"name": "alice"} + + monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None) + monkeypatch.setattr( + "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id + ) + monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) + # Job stays RUNNING forever — only the log marker can end the command. + monkeypatch.setattr( + "huggingface_hub.inspect_job", + lambda job_id: SimpleNamespace( + status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None) + ), + ) + pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun" + monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line])) + + cfg = draccus.parse( + TrainPipelineConfig, + args=[ + "--dataset.repo_id", + "u/d", + "--policy.type", + "act", + "--policy.repo_id", + "alice/myrun", + "--job.target", + "a10g-small", + ], + ) + # Must return via the model-pushed marker despite the perpetual RUNNING stage. 
+ submit_to_hf(cfg) + + +def test_submit_raises_when_wandb_enabled_without_key(monkeypatch): + """wandb.enable with no key reachable anywhere fails fast, before submitting.""" + import huggingface_hub + + monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") + + class FakeHfApi: + def __init__(self, token=None): + pass + + def whoami(self, token=None): + return {"name": "alice"} + + monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None) + + cfg = draccus.parse( + TrainPipelineConfig, + args=[ + "--dataset.repo_id", + "u/d", + "--policy.type", + "act", + "--job.target", + "a10g-small", + "--wandb.enable", + "true", + ], + ) + with pytest.raises(ValueError, match="WANDB_API_KEY"): + submit_to_hf(cfg) + + +@pytest.mark.timeout(15) +def test_submit_raises_when_job_ends_in_error(monkeypatch): + """A terminal non-COMPLETED stage with no model-pushed marker must raise with the status.""" + from types import SimpleNamespace + + import huggingface_hub + + monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") + + class FakeHfApi: + def __init__(self, token=None): + pass + + def whoami(self, token=None): + return {"name": "alice"} + + monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi) + monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None) + monkeypatch.setattr( + "lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id + ) + monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x")) + # Job fails: a terminal ERROR stage carrying the platform's status message. + monkeypatch.setattr( + "huggingface_hub.inspect_job", + lambda job_id: SimpleNamespace( + status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout") + ), + ) + # Logs end without the model-pushed marker. 
+ monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(())) + + cfg = draccus.parse( + TrainPipelineConfig, + args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"], + ) + with pytest.raises(RuntimeError, match=r"stage=ERROR \(Job timeout\)"): + submit_to_hf(cfg) diff --git a/tests/scripts/test_train_remote_dispatch.py b/tests/scripts/test_train_remote_dispatch.py new file mode 100644 index 000000000..3b634b563 --- /dev/null +++ b/tests/scripts/test_train_remote_dispatch.py @@ -0,0 +1,60 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import sys

import draccus

from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.act.configuration_act import ACTConfig  # noqa: F401 (registers --policy.type act)
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train


def _set_argv(monkeypatch, *args):
    """Replace ``sys.argv`` with ``lerobot-train`` followed by ``args`` for the current test."""
    command_line = ["lerobot-train"]
    command_line.extend(args)
    monkeypatch.setattr(sys, "argv", command_line)


def test_remote_target_detected_space_separated(monkeypatch):
    """``--job.target a10g-small`` given as two separate argv tokens is reported as remote."""
    _set_argv(monkeypatch, "--policy.type", "act", "--job.target", "a10g-small")
    detected = _remote_target_in_argv()
    assert detected is True


def test_remote_target_detected_equals(monkeypatch):
    """``--job.target=t4-small`` given as a single ``key=value`` token is reported as remote."""
    _set_argv(monkeypatch, "--job.target=t4-small")
    detected = _remote_target_in_argv()
    assert detected is True


def test_local_string_is_not_remote(monkeypatch):
    """The literal target ``local`` is not reported as remote."""
    _set_argv(monkeypatch, "--job.target", "local")
    detected = _remote_target_in_argv()
    assert detected is False


def test_no_target_is_not_remote(monkeypatch):
    """A command line without any ``--job.target`` is not reported as remote."""
    _set_argv(monkeypatch, "--policy.type", "act")
    detected = _remote_target_in_argv()
    assert detected is False


def test_train_dispatches_to_submit_when_remote(monkeypatch):
    """train() passes a config with a remote --job.target to lerobot.jobs.submit_to_hf."""
    import lerobot.jobs

    received = []

    def _fake_submit(cfg):
        # Record the config and hand back a sentinel so the return value can be checked.
        received.append(cfg)
        return "submitted"

    monkeypatch.setattr(lerobot.jobs, "submit_to_hf", _fake_submit)

    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    # train() must return the submitter's result, and the stub must have seen exactly this config.
    result = train(cfg)
    assert result == "submitted"
    assert received == [cfg]