mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-30 14:47:10 +00:00
feat(train): run training remotely on HF Jobs via --job.target
When --job.target names a GPU flavor, train() dispatches to lerobot.jobs.submit_to_hf instead of training locally: it authenticates, ensures the dataset is on the Hub (pushing a local-only one privately), serializes a pod-compatible train_config.json (strips client-only fields, points at the model repo), submits via HfApi.run_job with HF_TOKEN/WANDB_API_KEY secrets, then streams logs and finishes when the model is pushed. Wires push_checkpoint_to_hub into the training loop behind save_checkpoint_to_hub, and tags jobs/datasets/model with 'lerobot' + --job.tags.
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .hf import submit_to_hf
|
||||
|
||||
__all__ = ["submit_to_hf"]
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Make a training dataset reachable from an HF Job pod.
|
||||
|
||||
The pod can't see the host's ~/.cache/huggingface/lerobot, so the dataset has to
|
||||
live on the Hub: the pod downloads it by repo_id at train time (the forwarded
|
||||
HF_TOKEN covers private datasets). A dataset already on the Hub is used as-is; a
|
||||
local-only dataset is pushed to a PRIVATE repo first (never public).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
|
||||
def ensure_dataset_available(repo_id: str, *, api, tags: list[str] | None = None) -> None:
|
||||
"""Ensure repo_id resolves on the Hub, pushing a local-only dataset privately first.
|
||||
|
||||
`tags` are attached to the dataset only when we push it (an already-on-Hub
|
||||
dataset is left untouched). Raises RuntimeError if the dataset is neither on
|
||||
the Hub nor in the local cache.
|
||||
"""
|
||||
try:
|
||||
api.dataset_info(repo_id)
|
||||
return
|
||||
except RepositoryNotFoundError:
|
||||
pass
|
||||
|
||||
cache_root = Path(os.environ.get("HF_LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
local_present = (cache_root / repo_id / "meta" / "info.json").is_file()
|
||||
if not local_present:
|
||||
raise RuntimeError(
|
||||
f"Dataset '{repo_id}' is neither on the Hub nor in the local cache "
|
||||
f"({cache_root}). Record or download it first."
|
||||
)
|
||||
|
||||
print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...")
|
||||
# Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed.
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags)
|
||||
print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.")
|
||||
@@ -0,0 +1,332 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Run a lerobot training on HF Jobs (HuggingFace GPUs).
|
||||
|
||||
Ported and simplified from lelab's runners/hf_cloud.py: no UI log queue, no
|
||||
registry — just submit and stream to stdout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime as dt
|
||||
import io
|
||||
import json
|
||||
import netrc
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import get_token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
_SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||
|
||||
_TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"}
|
||||
|
||||
# Always attached to remote jobs and pushed datasets so LeRobot-originated work
|
||||
# is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags.
|
||||
LEROBOT_TAG = "lerobot"
|
||||
|
||||
|
||||
def resolve_job_tags(extra: list[str] | None) -> list[str]:
|
||||
"""Return the tag list for a run: the lerobot tag plus any extras, deduped, order-stable."""
|
||||
tags = [LEROBOT_TAG, *(extra or [])]
|
||||
seen: set[str] = set()
|
||||
return [t for t in tags if not (t in seen or seen.add(t))]
|
||||
|
||||
|
||||
def resolve_wandb_api_key() -> str | None:
|
||||
"""Host's wandb key for forwarding to the job: $WANDB_API_KEY, else ~/.netrc."""
|
||||
key = os.environ.get("WANDB_API_KEY")
|
||||
if key:
|
||||
return key
|
||||
try:
|
||||
rc = netrc.netrc()
|
||||
except (FileNotFoundError, netrc.NetrcParseError, OSError):
|
||||
return None
|
||||
auth = rc.authenticators("api.wandb.ai")
|
||||
if auth is None:
|
||||
return None
|
||||
_login, _account, password = auth
|
||||
return password or None
|
||||
|
||||
|
||||
def build_repo_id(username: str, job_name: str, now: dt.datetime) -> str:
|
||||
"""Generate the model repo id for a remote run: <user>/<job_name>_<timestamp>."""
|
||||
slug = _SLUG_RE.sub("-", job_name).strip("-") or "train"
|
||||
stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return f"{username}/{slug}_{stamp}"
|
||||
|
||||
|
||||
def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | None = None) -> Path:
|
||||
"""Write a train_config.json for the pod, with remote overrides applied.
|
||||
|
||||
The pod runs `lerobot-train --config_path=<dest>` and downloads the dataset
|
||||
by repo_id into its own cache. Client-only fields are stripped so the config
|
||||
is accepted by the trainer image: `job` (pure client orchestration) is always
|
||||
removed, and `save_checkpoint_to_hub` is removed unless explicitly enabled —
|
||||
older lerobot images reject unknown keys, so the default keeps the config
|
||||
compatible with the released `lerobot-gpu` image. `tags` are merged into
|
||||
policy.tags so the trained model the pod pushes carries them too.
|
||||
"""
|
||||
remote = copy.deepcopy(cfg)
|
||||
remote.policy.push_to_hub = True
|
||||
remote.policy.repo_id = repo_id
|
||||
# Don't pin the client's resolved device (e.g. "mps"); let the pod auto-detect its GPU.
|
||||
remote.policy.device = None
|
||||
# Drop any host-local dataset root; the pod resolves the dataset by repo_id.
|
||||
remote.dataset.root = None
|
||||
if tags:
|
||||
existing = list(remote.policy.tags or [])
|
||||
remote.policy.tags = existing + [t for t in tags if t not in existing]
|
||||
|
||||
# Round-trip through draccus to get the canonical, pod-parseable layout, then
|
||||
# drop the keys the released trainer image doesn't know about.
|
||||
buf = io.StringIO()
|
||||
with draccus.config_type("json"):
|
||||
draccus.dump(remote, buf, indent=4)
|
||||
data = json.loads(buf.getvalue())
|
||||
data.pop("job", None)
|
||||
if not remote.save_checkpoint_to_hub:
|
||||
data.pop("save_checkpoint_to_hub", None)
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_text(json.dumps(data, indent=4))
|
||||
return dest
|
||||
|
||||
|
||||
def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str:
|
||||
"""Upload train_config.json to the model repo and return the repo_id for --config_path."""
|
||||
from huggingface_hub import create_repo, upload_file
|
||||
|
||||
create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token)
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
config_path = build_remote_config_file(cfg, repo_id, Path(tmp) / "train_config.json", tags=tags)
|
||||
upload_file(
|
||||
path_or_fileobj=config_path,
|
||||
path_in_repo="train_config.json",
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
token=token,
|
||||
)
|
||||
return repo_id
|
||||
|
||||
|
||||
def _tail_logs(
|
||||
job_id: str,
|
||||
done: threading.Event,
|
||||
success_marker: str | None = None,
|
||||
success_event: threading.Event | None = None,
|
||||
) -> None:
|
||||
"""Stream job logs to stdout, reconnecting on dropped streams until done is set.
|
||||
|
||||
Each reconnect re-fetches the full buffered log, so we track how many lines
|
||||
were already printed and skip them — otherwise a fast-failing job's traceback
|
||||
gets reprinted on every reconnect.
|
||||
|
||||
When `success_marker` appears in a line, set `success_event` and `done` so the
|
||||
caller can finish as soon as the trained model lands on the Hub, rather than
|
||||
waiting out the platform's post-run finalization (which can add ~30s).
|
||||
"""
|
||||
from huggingface_hub import fetch_job_logs
|
||||
|
||||
printed = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
seen = 0
|
||||
for line in fetch_job_logs(job_id=job_id, follow=True):
|
||||
seen += 1
|
||||
if seen <= printed:
|
||||
continue # already shown on a previous connection
|
||||
printed = seen
|
||||
# fetch_job_logs yields SSE data without trailing newlines, so add one
|
||||
# per entry — otherwise all log lines concatenate onto a single line.
|
||||
print(line.rstrip("\n"), flush=True)
|
||||
if success_marker and success_event is not None and success_marker in line:
|
||||
success_event.set()
|
||||
done.set()
|
||||
return
|
||||
if done.is_set():
|
||||
return
|
||||
# Stream closed cleanly. Wait a moment so the status poller can mark
|
||||
# the job terminal before we reconnect (avoids re-tailing the buffer).
|
||||
if done.wait(3):
|
||||
return
|
||||
except Exception:
|
||||
if done.wait(2):
|
||||
return
|
||||
|
||||
|
||||
def _poll_until_done(
|
||||
job_id: str,
|
||||
done: threading.Event,
|
||||
poll_interval: float = 5.0,
|
||||
status_holder: dict | None = None,
|
||||
max_failures: int = 6,
|
||||
) -> str | None:
|
||||
"""Poll inspect_job until a terminal stage or until `done` is set.
|
||||
|
||||
Returns the terminal stage string, or None if `done` was set first (detach)
|
||||
or after `max_failures` consecutive inspect_job errors. When a terminal stage
|
||||
is reached and `status_holder` is given, records `status_holder["message"]`
|
||||
(the platform's status message, e.g. "Job timeout").
|
||||
"""
|
||||
from huggingface_hub import inspect_job
|
||||
|
||||
failures = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
info = inspect_job(job_id=job_id)
|
||||
failures = 0
|
||||
stage = info.status.stage.value
|
||||
if stage in _TERMINAL_STAGES:
|
||||
if status_holder is not None:
|
||||
status_holder["message"] = getattr(info.status, "message", None)
|
||||
done.set()
|
||||
return stage
|
||||
except Exception:
|
||||
failures += 1
|
||||
if failures >= max_failures:
|
||||
done.set()
|
||||
return None
|
||||
done.wait(poll_interval)
|
||||
return None
|
||||
|
||||
|
||||
def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
"""Submit a training job to HF Jobs infrastructure.
|
||||
|
||||
Validates cfg, resolves credentials, stages the config on the Hub, submits
|
||||
the job, then either tails logs until completion or detaches immediately.
|
||||
Ctrl-C detaches without cancelling the remote job.
|
||||
"""
|
||||
from huggingface_hub import HfApi, run_job
|
||||
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
token = get_token()
|
||||
if not token:
|
||||
raise RuntimeError("Not logged in to Hugging Face. Run `hf auth login` first.")
|
||||
|
||||
api = HfApi(token=token)
|
||||
user_info = api.whoami(token=token)
|
||||
username = user_info["name"]
|
||||
|
||||
now = dt.datetime.now()
|
||||
if cfg.policy is not None:
|
||||
base_name = cfg.job_name or cfg.policy.type
|
||||
repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
|
||||
cfg.policy.repo_id = repo_id
|
||||
cfg.policy.push_to_hub = True
|
||||
else:
|
||||
# Path-based policy is resolved inside validate(); fall back to a generic slug.
|
||||
repo_id = build_repo_id(username, cfg.job_name or "train", now)
|
||||
|
||||
cfg.validate()
|
||||
|
||||
secrets: dict[str, str] = {"HF_TOKEN": token}
|
||||
if cfg.wandb.enable:
|
||||
wandb_key = resolve_wandb_api_key()
|
||||
if wandb_key is None:
|
||||
raise ValueError(
|
||||
"wandb is enabled but no WANDB_API_KEY found. "
|
||||
"Set it via `export WANDB_API_KEY=...` or add it to ~/.netrc."
|
||||
)
|
||||
secrets["WANDB_API_KEY"] = wandb_key
|
||||
|
||||
tags = resolve_job_tags(cfg.job.tags)
|
||||
ensure_dataset_available(cfg.dataset.repo_id, api=api, tags=tags)
|
||||
|
||||
config_repo_id = _stage_config_on_hub(cfg, repo_id, token, tags=tags)
|
||||
command = ["lerobot-train", f"--config_path={config_repo_id}"]
|
||||
|
||||
print(f"Submitting job to HF Jobs (flavor={cfg.job.target}, image={cfg.job.image}) ...")
|
||||
job_info = run_job(
|
||||
image=cfg.job.image,
|
||||
command=command,
|
||||
flavor=cfg.job.target,
|
||||
secrets=secrets,
|
||||
timeout=cfg.job.timeout,
|
||||
# HF Jobs labels are key/value; expose each tag as a queryable label.
|
||||
labels=dict.fromkeys(tags, "true"),
|
||||
)
|
||||
job_id = job_info.id
|
||||
job_url = getattr(job_info, "url", None)
|
||||
print(f"Job submitted: {job_id}")
|
||||
if job_url:
|
||||
print(f" Job page: {job_url}")
|
||||
print(f" Model repo: https://huggingface.co/{repo_id}")
|
||||
print(f" Monitor: hf jobs logs {job_id}")
|
||||
print(f" Cancel: hf jobs cancel {job_id}")
|
||||
|
||||
if cfg.job.detach:
|
||||
return
|
||||
|
||||
done = threading.Event()
|
||||
detached = threading.Event()
|
||||
pushed_ok = threading.Event()
|
||||
stage_holder: dict[str, str | None] = {}
|
||||
|
||||
def _poll() -> None:
|
||||
stage_holder["stage"] = _poll_until_done(job_id, done, status_holder=stage_holder)
|
||||
|
||||
poll_thread = threading.Thread(target=_poll, daemon=True)
|
||||
poll_thread.start()
|
||||
# Finish as soon as the model is pushed, rather than waiting out the platform's
|
||||
# post-run finalization before the job stage flips to COMPLETED.
|
||||
success_marker = f"Model pushed to https://huggingface.co/{repo_id}"
|
||||
log_thread = threading.Thread(
|
||||
target=_tail_logs, args=(job_id, done, success_marker, pushed_ok), daemon=True
|
||||
)
|
||||
log_thread.start()
|
||||
|
||||
def _detach(sig, frame):
|
||||
detached.set()
|
||||
done.set()
|
||||
print("\nDetached. Job is still running.")
|
||||
print(f" Monitor: hf jobs logs {job_id}")
|
||||
print(f" Cancel: hf jobs cancel {job_id}")
|
||||
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
signal.signal(signal.SIGINT, _detach)
|
||||
try:
|
||||
# Timeout-based join so SIGINT is delivered to the main thread promptly.
|
||||
while poll_thread.is_alive():
|
||||
poll_thread.join(timeout=0.5)
|
||||
log_thread.join(timeout=5)
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
if detached.is_set():
|
||||
return
|
||||
|
||||
if pushed_ok.is_set():
|
||||
print(f"\nTraining complete — model pushed to https://huggingface.co/{repo_id}")
|
||||
return
|
||||
|
||||
stage = stage_holder.get("stage")
|
||||
if stage != "COMPLETED":
|
||||
message = stage_holder.get("message")
|
||||
detail = f" ({message})" if message else ""
|
||||
raise RuntimeError(
|
||||
f"Job {job_id} ended with stage={stage}{detail}. Check logs: hf jobs logs {job_id}"
|
||||
)
|
||||
@@ -41,6 +41,7 @@ from lerobot.common.train_utils import (
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
push_checkpoint_to_hub,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
@@ -188,6 +189,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||
"""
|
||||
if cfg.job.is_remote:
|
||||
from lerobot.jobs import submit_to_hf
|
||||
|
||||
return submit_to_hf(cfg)
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("accelerate", extra="training")
|
||||
@@ -655,6 +661,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
optim_state_dict=optim_state_dict,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if cfg.save_checkpoint_to_hub:
|
||||
push_checkpoint_to_hub(
|
||||
checkpoint_dir,
|
||||
cfg.policy.repo_id,
|
||||
private=cfg.policy.private,
|
||||
)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -735,8 +747,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def _remote_target_in_argv() -> bool:
|
||||
"""True when the CLI requests a remote HF Jobs run (--job.target=<non-local>)."""
|
||||
import sys
|
||||
|
||||
from lerobot.configs.default import JobConfig
|
||||
|
||||
target = None
|
||||
args = sys.argv[1:]
|
||||
for i, tok in enumerate(args):
|
||||
if tok == "--job.target" and i + 1 < len(args):
|
||||
target = args[i + 1]
|
||||
elif tok.startswith("--job.target="):
|
||||
target = tok.split("=", 1)[1]
|
||||
return JobConfig.is_remote_target(target)
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_plugins()
|
||||
if _remote_target_in_argv():
|
||||
# The policy device is resolved on the remote pod, not here, so silence the
|
||||
# client-side "Device '...' is not available" warning PreTrainedConfig emits
|
||||
# while parsing the config (it fires before train() can dispatch remotely).
|
||||
logging.getLogger("lerobot.configs.policies").setLevel(logging.ERROR)
|
||||
train()
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
|
||||
def _repo_not_found() -> RepositoryNotFoundError:
|
||||
req = httpx.Request("GET", "https://huggingface.co/datasets/test")
|
||||
resp = httpx.Response(404, request=req)
|
||||
return RepositoryNotFoundError("nope", response=resp)
|
||||
|
||||
|
||||
def _api_with_dataset(exists: bool):
|
||||
api = MagicMock()
|
||||
if exists:
|
||||
api.dataset_info.return_value = object()
|
||||
else:
|
||||
api.dataset_info.side_effect = _repo_not_found()
|
||||
return api
|
||||
|
||||
|
||||
def _make_local_cache(tmp_path, repo_id: str) -> None:
|
||||
"""Create the minimal local-cache layout that ensure_dataset_available checks."""
|
||||
info = tmp_path / repo_id / "meta" / "info.json"
|
||||
info.parent.mkdir(parents=True)
|
||||
info.write_text("{}")
|
||||
|
||||
|
||||
# Branch 1: dataset already on Hub → no push, no error (pod downloads by repo_id).
|
||||
def test_dataset_already_on_hub_is_noop():
|
||||
api = _api_with_dataset(True)
|
||||
assert ensure_dataset_available("user/ds", api=api) is None
|
||||
api.dataset_info.assert_called_once_with("user/ds")
|
||||
|
||||
|
||||
# Branch 2: not on Hub but present locally → always push privately.
|
||||
def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
|
||||
_make_local_cache(tmp_path, "user/ds")
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
mock_ds_cls = MagicMock()
|
||||
fake_datasets_module = MagicMock()
|
||||
fake_datasets_module.LeRobotDataset = mock_ds_cls
|
||||
monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)
|
||||
|
||||
assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None
|
||||
|
||||
mock_ds_cls.assert_called_once_with("user/ds")
|
||||
mock_ds_cls.return_value.push_to_hub.assert_called_once_with(private=True, tags=["lerobot", "lelab"])
|
||||
|
||||
|
||||
# Branch 3: not on Hub, NOT in local cache → RuntimeError "neither".
|
||||
def test_dataset_neither_on_hub_nor_local_raises(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
|
||||
# tmp_path is empty — no local cache.
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
with pytest.raises(RuntimeError, match="neither"):
|
||||
ensure_dataset_available("user/ds", api=api)
|
||||
@@ -0,0 +1,426 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import json
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.jobs.hf import (
|
||||
_poll_until_done,
|
||||
build_remote_config_file,
|
||||
build_repo_id,
|
||||
resolve_job_tags,
|
||||
resolve_wandb_api_key,
|
||||
submit_to_hf,
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_job_tags_always_includes_lerobot_and_dedups():
|
||||
assert resolve_job_tags(None) == ["lerobot"]
|
||||
assert resolve_job_tags([]) == ["lerobot"]
|
||||
assert resolve_job_tags(["lelab"]) == ["lerobot", "lelab"]
|
||||
# lerobot isn't duplicated if passed explicitly; order is stable.
|
||||
assert resolve_job_tags(["lelab", "lerobot", "lelab"]) == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def _fake_inspect(stage_value):
|
||||
return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=SimpleNamespace(value=stage_value)))
|
||||
|
||||
|
||||
def test_poll_until_done_returns_terminal_stage(monkeypatch):
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED"))
|
||||
done = threading.Event()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED"
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
|
||||
# Non-terminal forever; with done pre-set the loop must not block and returns None.
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING"))
|
||||
done = threading.Event()
|
||||
done.set()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) is None
|
||||
|
||||
|
||||
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
)
|
||||
done = threading.Event()
|
||||
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
assert result is None
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_env(monkeypatch):
|
||||
monkeypatch.setenv("WANDB_API_KEY", "abc123")
|
||||
assert resolve_wandb_api_key() == "abc123"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_missing(monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
monkeypatch.setenv("HOME", str(tmp_path)) # no ~/.netrc here
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: (_ for _ in ()).throw(FileNotFoundError()))
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_netrc(monkeypatch):
|
||||
# No env var → fall back to the wandb credentials in ~/.netrc.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
assert host == "api.wandb.ai"
|
||||
return ("login", "account", "netrc-secret")
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() == "netrc-secret"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_netrc_without_wandb_entry(monkeypatch):
|
||||
# ~/.netrc exists but has no api.wandb.ai entry → None.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_build_repo_id_sanitizes_and_timestamps():
|
||||
now = dt.datetime(2026, 6, 19, 10, 22, 3)
|
||||
assert build_repo_id("alice", "act", now) == "alice/act_2026-06-19_10-22-03"
|
||||
# Runs of illegal characters collapse to a single dash; edges are trimmed.
|
||||
assert build_repo_id("alice", "my cool/run!!", now) == "alice/my-cool-run_2026-06-19_10-22-03"
|
||||
# A name with nothing usable falls back to "train".
|
||||
assert build_repo_id("alice", "///", now) == "alice/train_2026-06-19_10-22-03"
|
||||
|
||||
|
||||
def _minimal_cfg():
|
||||
return draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
|
||||
|
||||
def test_build_remote_config_applies_overrides(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
out = build_remote_config_file(cfg, "u/run", dest)
|
||||
assert out == dest
|
||||
data = json.loads(dest.read_text())
|
||||
# `job` is client-only orchestration and must be stripped for the pod.
|
||||
assert "job" not in data
|
||||
# save_checkpoint_to_hub defaults off → omitted so older images accept the config.
|
||||
assert "save_checkpoint_to_hub" not in data
|
||||
assert data["policy"]["push_to_hub"] is True
|
||||
assert data["policy"]["repo_id"] == "u/run"
|
||||
assert data["policy"]["device"] is None # pod auto-detects its GPU
|
||||
assert data["dataset"]["root"] is None # pod resolves the dataset by repo_id
|
||||
# the caller's cfg must be left untouched (function works on a deep copy)
|
||||
assert cfg.job.target == "a10g-small"
|
||||
assert cfg.save_checkpoint_to_hub is False
|
||||
|
||||
|
||||
def test_build_remote_config_includes_checkpoint_flag_when_enabled(tmp_path):
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--save_checkpoint_to_hub",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest)
|
||||
data = json.loads(dest.read_text())
|
||||
# explicitly enabled → kept in the config (requires a matching trainer image).
|
||||
assert data["save_checkpoint_to_hub"] is True
|
||||
assert "job" not in data
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_into_policy(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# tags propagate to the model the pod pushes.
|
||||
assert data["policy"]["tags"] == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_without_duplicating(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
cfg.policy.tags = ["existing", "lerobot"]
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# pre-existing policy tags are kept; only genuinely-new tags are appended (no dup "lerobot").
|
||||
assert data["policy"]["tags"] == ["existing", "lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_submit_requires_login(monkeypatch):
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: None)
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="hf auth login"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
# Patch get_token
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
# Patch HfApi so whoami returns alice
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
|
||||
# ensure_dataset_available returns None; patch it out so no Hub access happens
|
||||
# (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`).
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
|
||||
# Patch _stage_config_on_hub to skip network
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub",
|
||||
lambda cfg, repo_id, token, tags=None: repo_id,
|
||||
)
|
||||
|
||||
# Patch run_job to return a fake job
|
||||
fake_job = MagicMock()
|
||||
fake_job.id = "job-123"
|
||||
run_job_calls = []
|
||||
|
||||
def fake_run_job(**kwargs):
|
||||
run_job_calls.append(kwargs)
|
||||
return fake_job
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--job.detach",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
|
||||
# Must NOT raise (pre-fix this raised ValueError about missing repo_id)
|
||||
submit_to_hf(cfg)
|
||||
|
||||
assert len(run_job_calls) == 1, "run_job should have been called exactly once"
|
||||
assert cfg.policy.repo_id is not None
|
||||
assert cfg.policy.repo_id.startswith("alice/")
|
||||
call = run_job_calls[0]
|
||||
# The pod runs `lerobot-train --config_path=<staged repo>` on the requested flavor/image.
|
||||
assert call["command"][0] == "lerobot-train"
|
||||
assert call["command"][1].startswith("--config_path=")
|
||||
assert call["flavor"] == "a10g-small"
|
||||
assert call["image"] == "huggingface/lerobot-gpu:latest"
|
||||
# The Hub token is forwarded so the pod can pull the (possibly private) dataset.
|
||||
assert call["secrets"]["HF_TOKEN"] == "tok"
|
||||
# Every job carries the lerobot tag as a queryable label.
|
||||
assert call["labels"].get("lerobot") == "true"
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_when_job_completes(monkeypatch):
|
||||
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job is already COMPLETED on the first poll.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None)
|
||||
),
|
||||
)
|
||||
# Log stream ends immediately.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Runs in the pytest main thread (signal handler install requires it); the
|
||||
# @timeout marker fails the test instead of hanging if it regresses.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_on_model_pushed_marker(monkeypatch):
|
||||
"""Finish when the model-pushed log appears, even if the job stage never flips."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job stays RUNNING forever — only the log marker can end the command.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None)
|
||||
),
|
||||
)
|
||||
pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--policy.repo_id",
|
||||
"alice/myrun",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
],
|
||||
)
|
||||
# Must return via the model-pushed marker despite the perpetual RUNNING stage.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
|
||||
"""wandb.enable with no key reachable anywhere fails fast, before submitting."""
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--wandb.enable",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="WANDB_API_KEY"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_raises_when_job_ends_in_error(monkeypatch):
|
||||
"""A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job fails: a terminal ERROR stage carrying the platform's status message.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout")
|
||||
),
|
||||
)
|
||||
# Logs end without the model-pushed marker.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=r"stage=ERROR \(Job timeout\)"):
|
||||
submit_to_hf(cfg)
|
||||
@@ -0,0 +1,60 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
import draccus
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401 (registers --policy.type act)
|
||||
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train
|
||||
|
||||
|
||||
def _set_argv(monkeypatch, *args):
|
||||
monkeypatch.setattr(sys, "argv", ["lerobot-train", *args])
|
||||
|
||||
|
||||
def test_remote_target_detected_space_separated(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act", "--job.target", "a10g-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_remote_target_detected_equals(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target=t4-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_local_string_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target", "local")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_no_target_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_train_dispatches_to_submit_when_remote(monkeypatch):
|
||||
"""A remote --job.target short-circuits train() to the HF Jobs submitter."""
|
||||
import lerobot.jobs
|
||||
|
||||
captured = []
|
||||
monkeypatch.setattr(lerobot.jobs, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted")
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Returns the submitter's result and never enters the local training path.
|
||||
assert train(cfg) == "submitted"
|
||||
assert captured == [cfg]
|
||||
Reference in New Issue
Block a user