mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-01 15:17:05 +00:00
feat(train): run training remotely on HF Jobs via --job.target
When --job.target names a GPU flavor, train() dispatches to lerobot.jobs.submit_to_hf instead of training locally: it authenticates, ensures the dataset is on the Hub (pushing a local-only one privately), serializes a pod-compatible train_config.json (strips client-only fields, points at the model repo), submits via HfApi.run_job with HF_TOKEN/WANDB_API_KEY secrets, then streams logs and finishes when the model is pushed. Wires push_checkpoint_to_hub into the training loop behind save_checkpoint_to_hub, and tags jobs/datasets/model with 'lerobot' + --job.tags.
This commit is contained in:
@@ -0,0 +1,78 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
|
||||
def _repo_not_found() -> RepositoryNotFoundError:
|
||||
req = httpx.Request("GET", "https://huggingface.co/datasets/test")
|
||||
resp = httpx.Response(404, request=req)
|
||||
return RepositoryNotFoundError("nope", response=resp)
|
||||
|
||||
|
||||
def _api_with_dataset(exists: bool):
|
||||
api = MagicMock()
|
||||
if exists:
|
||||
api.dataset_info.return_value = object()
|
||||
else:
|
||||
api.dataset_info.side_effect = _repo_not_found()
|
||||
return api
|
||||
|
||||
|
||||
def _make_local_cache(tmp_path, repo_id: str) -> None:
|
||||
"""Create the minimal local-cache layout that ensure_dataset_available checks."""
|
||||
info = tmp_path / repo_id / "meta" / "info.json"
|
||||
info.parent.mkdir(parents=True)
|
||||
info.write_text("{}")
|
||||
|
||||
|
||||
# Branch 1: dataset already on Hub → no push, no error (pod downloads by repo_id).
|
||||
def test_dataset_already_on_hub_is_noop():
|
||||
api = _api_with_dataset(True)
|
||||
assert ensure_dataset_available("user/ds", api=api) is None
|
||||
api.dataset_info.assert_called_once_with("user/ds")
|
||||
|
||||
|
||||
# Branch 2: not on Hub but present locally → always push privately.
|
||||
def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
|
||||
_make_local_cache(tmp_path, "user/ds")
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
mock_ds_cls = MagicMock()
|
||||
fake_datasets_module = MagicMock()
|
||||
fake_datasets_module.LeRobotDataset = mock_ds_cls
|
||||
monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)
|
||||
|
||||
assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None
|
||||
|
||||
mock_ds_cls.assert_called_once_with("user/ds")
|
||||
mock_ds_cls.return_value.push_to_hub.assert_called_once_with(private=True, tags=["lerobot", "lelab"])
|
||||
|
||||
|
||||
# Branch 3: not on Hub, NOT in local cache → RuntimeError "neither".
|
||||
def test_dataset_neither_on_hub_nor_local_raises(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
|
||||
# tmp_path is empty — no local cache.
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
with pytest.raises(RuntimeError, match="neither"):
|
||||
ensure_dataset_available("user/ds", api=api)
|
||||
@@ -0,0 +1,426 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import json
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.jobs.hf import (
|
||||
_poll_until_done,
|
||||
build_remote_config_file,
|
||||
build_repo_id,
|
||||
resolve_job_tags,
|
||||
resolve_wandb_api_key,
|
||||
submit_to_hf,
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_job_tags_always_includes_lerobot_and_dedups():
|
||||
assert resolve_job_tags(None) == ["lerobot"]
|
||||
assert resolve_job_tags([]) == ["lerobot"]
|
||||
assert resolve_job_tags(["lelab"]) == ["lerobot", "lelab"]
|
||||
# lerobot isn't duplicated if passed explicitly; order is stable.
|
||||
assert resolve_job_tags(["lelab", "lerobot", "lelab"]) == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def _fake_inspect(stage_value):
|
||||
return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=SimpleNamespace(value=stage_value)))
|
||||
|
||||
|
||||
def test_poll_until_done_returns_terminal_stage(monkeypatch):
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED"))
|
||||
done = threading.Event()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED"
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
|
||||
# Non-terminal forever; with done pre-set the loop must not block and returns None.
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING"))
|
||||
done = threading.Event()
|
||||
done.set()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) is None
|
||||
|
||||
|
||||
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
)
|
||||
done = threading.Event()
|
||||
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
assert result is None
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_env(monkeypatch):
|
||||
monkeypatch.setenv("WANDB_API_KEY", "abc123")
|
||||
assert resolve_wandb_api_key() == "abc123"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_missing(monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
monkeypatch.setenv("HOME", str(tmp_path)) # no ~/.netrc here
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: (_ for _ in ()).throw(FileNotFoundError()))
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_netrc(monkeypatch):
|
||||
# No env var → fall back to the wandb credentials in ~/.netrc.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
assert host == "api.wandb.ai"
|
||||
return ("login", "account", "netrc-secret")
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() == "netrc-secret"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_netrc_without_wandb_entry(monkeypatch):
|
||||
# ~/.netrc exists but has no api.wandb.ai entry → None.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_build_repo_id_sanitizes_and_timestamps():
|
||||
now = dt.datetime(2026, 6, 19, 10, 22, 3)
|
||||
assert build_repo_id("alice", "act", now) == "alice/act_2026-06-19_10-22-03"
|
||||
# Runs of illegal characters collapse to a single dash; edges are trimmed.
|
||||
assert build_repo_id("alice", "my cool/run!!", now) == "alice/my-cool-run_2026-06-19_10-22-03"
|
||||
# A name with nothing usable falls back to "train".
|
||||
assert build_repo_id("alice", "///", now) == "alice/train_2026-06-19_10-22-03"
|
||||
|
||||
|
||||
def _minimal_cfg():
|
||||
return draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
|
||||
|
||||
def test_build_remote_config_applies_overrides(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
out = build_remote_config_file(cfg, "u/run", dest)
|
||||
assert out == dest
|
||||
data = json.loads(dest.read_text())
|
||||
# `job` is client-only orchestration and must be stripped for the pod.
|
||||
assert "job" not in data
|
||||
# save_checkpoint_to_hub defaults off → omitted so older images accept the config.
|
||||
assert "save_checkpoint_to_hub" not in data
|
||||
assert data["policy"]["push_to_hub"] is True
|
||||
assert data["policy"]["repo_id"] == "u/run"
|
||||
assert data["policy"]["device"] is None # pod auto-detects its GPU
|
||||
assert data["dataset"]["root"] is None # pod resolves the dataset by repo_id
|
||||
# the caller's cfg must be left untouched (function works on a deep copy)
|
||||
assert cfg.job.target == "a10g-small"
|
||||
assert cfg.save_checkpoint_to_hub is False
|
||||
|
||||
|
||||
def test_build_remote_config_includes_checkpoint_flag_when_enabled(tmp_path):
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--save_checkpoint_to_hub",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest)
|
||||
data = json.loads(dest.read_text())
|
||||
# explicitly enabled → kept in the config (requires a matching trainer image).
|
||||
assert data["save_checkpoint_to_hub"] is True
|
||||
assert "job" not in data
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_into_policy(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# tags propagate to the model the pod pushes.
|
||||
assert data["policy"]["tags"] == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_without_duplicating(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
cfg.policy.tags = ["existing", "lerobot"]
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# pre-existing policy tags are kept; only genuinely-new tags are appended (no dup "lerobot").
|
||||
assert data["policy"]["tags"] == ["existing", "lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_submit_requires_login(monkeypatch):
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: None)
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="hf auth login"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
# Patch get_token
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
# Patch HfApi so whoami returns alice
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
|
||||
# ensure_dataset_available returns None; patch it out so no Hub access happens
|
||||
# (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`).
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
|
||||
# Patch _stage_config_on_hub to skip network
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub",
|
||||
lambda cfg, repo_id, token, tags=None: repo_id,
|
||||
)
|
||||
|
||||
# Patch run_job to return a fake job
|
||||
fake_job = MagicMock()
|
||||
fake_job.id = "job-123"
|
||||
run_job_calls = []
|
||||
|
||||
def fake_run_job(**kwargs):
|
||||
run_job_calls.append(kwargs)
|
||||
return fake_job
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--job.detach",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
|
||||
# Must NOT raise (pre-fix this raised ValueError about missing repo_id)
|
||||
submit_to_hf(cfg)
|
||||
|
||||
assert len(run_job_calls) == 1, "run_job should have been called exactly once"
|
||||
assert cfg.policy.repo_id is not None
|
||||
assert cfg.policy.repo_id.startswith("alice/")
|
||||
call = run_job_calls[0]
|
||||
# The pod runs `lerobot-train --config_path=<staged repo>` on the requested flavor/image.
|
||||
assert call["command"][0] == "lerobot-train"
|
||||
assert call["command"][1].startswith("--config_path=")
|
||||
assert call["flavor"] == "a10g-small"
|
||||
assert call["image"] == "huggingface/lerobot-gpu:latest"
|
||||
# The Hub token is forwarded so the pod can pull the (possibly private) dataset.
|
||||
assert call["secrets"]["HF_TOKEN"] == "tok"
|
||||
# Every job carries the lerobot tag as a queryable label.
|
||||
assert call["labels"].get("lerobot") == "true"
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_when_job_completes(monkeypatch):
|
||||
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job is already COMPLETED on the first poll.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None)
|
||||
),
|
||||
)
|
||||
# Log stream ends immediately.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Runs in the pytest main thread (signal handler install requires it); the
|
||||
# @timeout marker fails the test instead of hanging if it regresses.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_on_model_pushed_marker(monkeypatch):
|
||||
"""Finish when the model-pushed log appears, even if the job stage never flips."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job stays RUNNING forever — only the log marker can end the command.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None)
|
||||
),
|
||||
)
|
||||
pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--policy.repo_id",
|
||||
"alice/myrun",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
],
|
||||
)
|
||||
# Must return via the model-pushed marker despite the perpetual RUNNING stage.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
|
||||
"""wandb.enable with no key reachable anywhere fails fast, before submitting."""
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--wandb.enable",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="WANDB_API_KEY"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_raises_when_job_ends_in_error(monkeypatch):
|
||||
"""A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job fails: a terminal ERROR stage carrying the platform's status message.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout")
|
||||
),
|
||||
)
|
||||
# Logs end without the model-pushed marker.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=r"stage=ERROR \(Job timeout\)"):
|
||||
submit_to_hf(cfg)
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401 (registers --policy.type act)
|
||||
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train
|
||||
|
||||
|
||||
def _set_argv(monkeypatch, *args):
|
||||
monkeypatch.setattr(sys, "argv", ["lerobot-train", *args])
|
||||
|
||||
|
||||
def test_remote_target_detected_space_separated(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act", "--job.target", "a10g-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_remote_target_detected_equals(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target=t4-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_local_string_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target", "local")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_no_target_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_train_dispatches_to_submit_when_remote(monkeypatch):
|
||||
"""A remote --job.target short-circuits train() to the HF Jobs submitter."""
|
||||
import lerobot.jobs
|
||||
|
||||
captured = []
|
||||
monkeypatch.setattr(lerobot.jobs, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted")
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Returns the submitter's result and never enters the local training path.
|
||||
assert train(cfg) == "submitted"
|
||||
assert captured == [cfg]
|
||||
Reference in New Issue
Block a user