feat(train): run training remotely on HF Jobs via --job.target

When --job.target names a GPU flavor, train() dispatches to lerobot.jobs.submit_to_hf
instead of training locally: it authenticates, ensures the dataset is on the Hub
(pushing a local-only one privately), serializes a pod-compatible train_config.json
(strips client-only fields, points at the model repo), submits via HfApi.run_job
with HF_TOKEN/WANDB_API_KEY secrets, then streams logs and finishes when the model
is pushed. Wires push_checkpoint_to_hub into the training loop behind
save_checkpoint_to_hub, and tags jobs/datasets/model with 'lerobot' + --job.tags.
This commit is contained in:
Nicolas Rabault
2026-06-22 15:43:53 +02:00
parent 870d71aeb0
commit 607a8a6b68
7 changed files with 1002 additions and 0 deletions
+78
View File
@@ -0,0 +1,78 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest.mock import MagicMock
import httpx
import pytest
from huggingface_hub.errors import RepositoryNotFoundError
from lerobot.jobs.dataset import ensure_dataset_available
def _repo_not_found() -> RepositoryNotFoundError:
req = httpx.Request("GET", "https://huggingface.co/datasets/test")
resp = httpx.Response(404, request=req)
return RepositoryNotFoundError("nope", response=resp)
def _api_with_dataset(exists: bool):
api = MagicMock()
if exists:
api.dataset_info.return_value = object()
else:
api.dataset_info.side_effect = _repo_not_found()
return api
def _make_local_cache(tmp_path, repo_id: str) -> None:
"""Create the minimal local-cache layout that ensure_dataset_available checks."""
info = tmp_path / repo_id / "meta" / "info.json"
info.parent.mkdir(parents=True)
info.write_text("{}")
# Branch 1: dataset already on Hub → no push, no error (pod downloads by repo_id).
def test_dataset_already_on_hub_is_noop():
api = _api_with_dataset(True)
assert ensure_dataset_available("user/ds", api=api) is None
api.dataset_info.assert_called_once_with("user/ds")
# Branch 2: not on Hub but present locally → always push privately.
def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
_make_local_cache(tmp_path, "user/ds")
api = _api_with_dataset(False)
mock_ds_cls = MagicMock()
fake_datasets_module = MagicMock()
fake_datasets_module.LeRobotDataset = mock_ds_cls
monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)
assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None
mock_ds_cls.assert_called_once_with("user/ds")
mock_ds_cls.return_value.push_to_hub.assert_called_once_with(private=True, tags=["lerobot", "lelab"])
# Branch 3: not on Hub, NOT in local cache → RuntimeError "neither".
def test_dataset_neither_on_hub_nor_local_raises(tmp_path, monkeypatch):
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
# tmp_path is empty — no local cache.
api = _api_with_dataset(False)
with pytest.raises(RuntimeError, match="neither"):
ensure_dataset_available("user/ds", api=api)
+426
View File
@@ -0,0 +1,426 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime as dt
import json
import threading
from types import SimpleNamespace
import draccus
import pytest
from lerobot.configs.train import TrainPipelineConfig
from lerobot.jobs.hf import (
_poll_until_done,
build_remote_config_file,
build_repo_id,
resolve_job_tags,
resolve_wandb_api_key,
submit_to_hf,
)
def test_resolve_job_tags_always_includes_lerobot_and_dedups():
assert resolve_job_tags(None) == ["lerobot"]
assert resolve_job_tags([]) == ["lerobot"]
assert resolve_job_tags(["lelab"]) == ["lerobot", "lelab"]
# lerobot isn't duplicated if passed explicitly; order is stable.
assert resolve_job_tags(["lelab", "lerobot", "lelab"]) == ["lerobot", "lelab"]
def _fake_inspect(stage_value):
return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=SimpleNamespace(value=stage_value)))
def test_poll_until_done_returns_terminal_stage(monkeypatch):
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED"))
done = threading.Event()
assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED"
assert done.is_set()
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
# Non-terminal forever; with done pre-set the loop must not block and returns None.
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING"))
done = threading.Event()
done.set()
assert _poll_until_done("j", done, poll_interval=0.01) is None
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
monkeypatch.setattr(
"huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
)
done = threading.Event()
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
assert result is None
assert done.is_set()
def test_resolve_wandb_key_from_env(monkeypatch):
monkeypatch.setenv("WANDB_API_KEY", "abc123")
assert resolve_wandb_api_key() == "abc123"
def test_resolve_wandb_key_missing(monkeypatch, tmp_path):
monkeypatch.delenv("WANDB_API_KEY", raising=False)
monkeypatch.setenv("HOME", str(tmp_path)) # no ~/.netrc here
monkeypatch.setattr("netrc.netrc", lambda *a, **k: (_ for _ in ()).throw(FileNotFoundError()))
assert resolve_wandb_api_key() is None
def test_resolve_wandb_key_from_netrc(monkeypatch):
# No env var → fall back to the wandb credentials in ~/.netrc.
monkeypatch.delenv("WANDB_API_KEY", raising=False)
class _FakeNetrc:
def authenticators(self, host):
assert host == "api.wandb.ai"
return ("login", "account", "netrc-secret")
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
assert resolve_wandb_api_key() == "netrc-secret"
def test_resolve_wandb_key_netrc_without_wandb_entry(monkeypatch):
# ~/.netrc exists but has no api.wandb.ai entry → None.
monkeypatch.delenv("WANDB_API_KEY", raising=False)
class _FakeNetrc:
def authenticators(self, host):
return None
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
assert resolve_wandb_api_key() is None
def test_build_repo_id_sanitizes_and_timestamps():
now = dt.datetime(2026, 6, 19, 10, 22, 3)
assert build_repo_id("alice", "act", now) == "alice/act_2026-06-19_10-22-03"
# Runs of illegal characters collapse to a single dash; edges are trimmed.
assert build_repo_id("alice", "my cool/run!!", now) == "alice/my-cool-run_2026-06-19_10-22-03"
# A name with nothing usable falls back to "train".
assert build_repo_id("alice", "///", now) == "alice/train_2026-06-19_10-22-03"
def _minimal_cfg():
return draccus.parse(
TrainPipelineConfig,
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
)
def test_build_remote_config_applies_overrides(tmp_path):
cfg = _minimal_cfg()
dest = tmp_path / "train_config.json"
out = build_remote_config_file(cfg, "u/run", dest)
assert out == dest
data = json.loads(dest.read_text())
# `job` is client-only orchestration and must be stripped for the pod.
assert "job" not in data
# save_checkpoint_to_hub defaults off → omitted so older images accept the config.
assert "save_checkpoint_to_hub" not in data
assert data["policy"]["push_to_hub"] is True
assert data["policy"]["repo_id"] == "u/run"
assert data["policy"]["device"] is None # pod auto-detects its GPU
assert data["dataset"]["root"] is None # pod resolves the dataset by repo_id
# the caller's cfg must be left untouched (function works on a deep copy)
assert cfg.job.target == "a10g-small"
assert cfg.save_checkpoint_to_hub is False
def test_build_remote_config_includes_checkpoint_flag_when_enabled(tmp_path):
cfg = draccus.parse(
TrainPipelineConfig,
args=[
"--dataset.repo_id",
"u/d",
"--policy.type",
"act",
"--job.target",
"a10g-small",
"--save_checkpoint_to_hub",
"true",
],
)
dest = tmp_path / "train_config.json"
build_remote_config_file(cfg, "u/run", dest)
data = json.loads(dest.read_text())
# explicitly enabled → kept in the config (requires a matching trainer image).
assert data["save_checkpoint_to_hub"] is True
assert "job" not in data
def test_build_remote_config_merges_tags_into_policy(tmp_path):
cfg = _minimal_cfg()
dest = tmp_path / "train_config.json"
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
data = json.loads(dest.read_text())
# tags propagate to the model the pod pushes.
assert data["policy"]["tags"] == ["lerobot", "lelab"]
def test_build_remote_config_merges_tags_without_duplicating(tmp_path):
cfg = _minimal_cfg()
cfg.policy.tags = ["existing", "lerobot"]
dest = tmp_path / "train_config.json"
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
data = json.loads(dest.read_text())
# pre-existing policy tags are kept; only genuinely-new tags are appended (no dup "lerobot").
assert data["policy"]["tags"] == ["existing", "lerobot", "lelab"]
def test_submit_requires_login(monkeypatch):
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: None)
cfg = draccus.parse(
TrainPipelineConfig,
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
)
with pytest.raises(RuntimeError, match="hf auth login"):
submit_to_hf(cfg)
def test_submit_passes_validation_and_submits(monkeypatch):
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
from unittest.mock import MagicMock
import huggingface_hub
# Patch get_token
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
# Patch HfApi so whoami returns alice
class FakeHfApi:
def __init__(self, token=None):
pass
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
# ensure_dataset_available returns None; patch it out so no Hub access happens
# (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`).
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
# Patch _stage_config_on_hub to skip network
monkeypatch.setattr(
"lerobot.jobs.hf._stage_config_on_hub",
lambda cfg, repo_id, token, tags=None: repo_id,
)
# Patch run_job to return a fake job
fake_job = MagicMock()
fake_job.id = "job-123"
run_job_calls = []
def fake_run_job(**kwargs):
run_job_calls.append(kwargs)
return fake_job
monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job)
cfg = draccus.parse(
TrainPipelineConfig,
args=[
"--dataset.repo_id",
"u/d",
"--policy.type",
"act",
"--job.target",
"a10g-small",
"--job.detach",
"true",
],
)
# Must NOT raise (pre-fix this raised ValueError about missing repo_id)
submit_to_hf(cfg)
assert len(run_job_calls) == 1, "run_job should have been called exactly once"
assert cfg.policy.repo_id is not None
assert cfg.policy.repo_id.startswith("alice/")
call = run_job_calls[0]
# The pod runs `lerobot-train --config_path=<staged repo>` on the requested flavor/image.
assert call["command"][0] == "lerobot-train"
assert call["command"][1].startswith("--config_path=")
assert call["flavor"] == "a10g-small"
assert call["image"] == "huggingface/lerobot-gpu:latest"
# The Hub token is forwarded so the pod can pull the (possibly private) dataset.
assert call["secrets"]["HF_TOKEN"] == "tok"
# Every job carries the lerobot tag as a queryable label.
assert call["labels"].get("lerobot") == "true"
@pytest.mark.timeout(15)
def test_submit_returns_when_job_completes(monkeypatch):
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
from types import SimpleNamespace
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
class FakeHfApi:
def __init__(self, token=None):
pass
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
monkeypatch.setattr(
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
)
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
# Job is already COMPLETED on the first poll.
monkeypatch.setattr(
"huggingface_hub.inspect_job",
lambda job_id: SimpleNamespace(
status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None)
),
)
# Log stream ends immediately.
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
cfg = draccus.parse(
TrainPipelineConfig,
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
)
# Runs in the pytest main thread (signal handler install requires it); the
# @timeout marker fails the test instead of hanging if it regresses.
submit_to_hf(cfg)
@pytest.mark.timeout(15)
def test_submit_returns_on_model_pushed_marker(monkeypatch):
"""Finish when the model-pushed log appears, even if the job stage never flips."""
from types import SimpleNamespace
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
class FakeHfApi:
def __init__(self, token=None):
pass
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
monkeypatch.setattr(
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
)
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
# Job stays RUNNING forever — only the log marker can end the command.
monkeypatch.setattr(
"huggingface_hub.inspect_job",
lambda job_id: SimpleNamespace(
status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None)
),
)
pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
cfg = draccus.parse(
TrainPipelineConfig,
args=[
"--dataset.repo_id",
"u/d",
"--policy.type",
"act",
"--policy.repo_id",
"alice/myrun",
"--job.target",
"a10g-small",
],
)
# Must return via the model-pushed marker despite the perpetual RUNNING stage.
submit_to_hf(cfg)
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
"""wandb.enable with no key reachable anywhere fails fast, before submitting."""
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
class FakeHfApi:
def __init__(self, token=None):
pass
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)
cfg = draccus.parse(
TrainPipelineConfig,
args=[
"--dataset.repo_id",
"u/d",
"--policy.type",
"act",
"--job.target",
"a10g-small",
"--wandb.enable",
"true",
],
)
with pytest.raises(ValueError, match="WANDB_API_KEY"):
submit_to_hf(cfg)
@pytest.mark.timeout(15)
def test_submit_raises_when_job_ends_in_error(monkeypatch):
"""A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
from types import SimpleNamespace
import huggingface_hub
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
class FakeHfApi:
def __init__(self, token=None):
pass
def whoami(self, token=None):
return {"name": "alice"}
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
monkeypatch.setattr(
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
)
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
# Job fails: a terminal ERROR stage carrying the platform's status message.
monkeypatch.setattr(
"huggingface_hub.inspect_job",
lambda job_id: SimpleNamespace(
status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout")
),
)
# Logs end without the model-pushed marker.
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
cfg = draccus.parse(
TrainPipelineConfig,
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
)
with pytest.raises(RuntimeError, match=r"stage=ERROR \(Job timeout\)"):
submit_to_hf(cfg)
@@ -0,0 +1,60 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import draccus
from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401 (registers --policy.type act)
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train
def _set_argv(monkeypatch, *args):
monkeypatch.setattr(sys, "argv", ["lerobot-train", *args])
def test_remote_target_detected_space_separated(monkeypatch):
_set_argv(monkeypatch, "--policy.type", "act", "--job.target", "a10g-small")
assert _remote_target_in_argv() is True
def test_remote_target_detected_equals(monkeypatch):
_set_argv(monkeypatch, "--job.target=t4-small")
assert _remote_target_in_argv() is True
def test_local_string_is_not_remote(monkeypatch):
_set_argv(monkeypatch, "--job.target", "local")
assert _remote_target_in_argv() is False
def test_no_target_is_not_remote(monkeypatch):
_set_argv(monkeypatch, "--policy.type", "act")
assert _remote_target_in_argv() is False
def test_train_dispatches_to_submit_when_remote(monkeypatch):
"""A remote --job.target short-circuits train() to the HF Jobs submitter."""
import lerobot.jobs
captured = []
monkeypatch.setattr(lerobot.jobs, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted")
cfg = draccus.parse(
TrainPipelineConfig,
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
)
# Returns the submitter's result and never enters the local training path.
assert train(cfg) == "submitted"
assert captured == [cfg]