mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 20:27:05 +00:00
fix(jobs): address claude review findings on remote training
Resolve the claude[bot] review on #3856: - Reject reward-model training under --job.target with a clear error instead of crashing on a None policy inside build_remote_config_file. - Support --policy.path remote runs: validate() no longer requires repo_id for remote runs (it is auto-generated in submit_to_hf), and repo_id/push_to_hub are now set after validate() resolves the policy. - Narrow the bare `except Exception` in _tail_logs/_poll_until_done to (OSError, httpx.HTTPError) so programming errors surface instead of being silently retried or counted as job failures. - Install the SIGINT detach handler only on the main thread. - Generate model repo timestamps in UTC.
This commit is contained in:
@@ -215,7 +215,14 @@ class TrainPipelineConfig(HubMixin):
|
||||
self.optimizer = active_cfg.get_optimizer_preset()
|
||||
self.scheduler = active_cfg.get_scheduler_preset()
|
||||
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
# Remote runs auto-generate the repo_id in submit_to_hf (the policy may only be
|
||||
# resolved here, from --policy.path), so don't demand it up front for them.
|
||||
if (
|
||||
hasattr(active_cfg, "push_to_hub")
|
||||
and active_cfg.push_to_hub
|
||||
and not active_cfg.repo_id
|
||||
and not self.job.is_remote
|
||||
):
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
|
||||
if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id):
|
||||
|
||||
+34
-15
@@ -31,6 +31,7 @@ import threading
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from huggingface_hub import (
|
||||
HfApi,
|
||||
create_repo,
|
||||
@@ -48,6 +49,12 @@ _SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||
|
||||
_TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"}
|
||||
|
||||
# huggingface_hub 1.x runs on httpx: transient HTTP/transport failures surface as
|
||||
# httpx.HTTPError and socket-level errors as OSError. Catching only these keeps real
|
||||
# bugs (TypeError, AttributeError, ...) from being silently retried or counted as
|
||||
# job failures.
|
||||
_TRANSIENT_NET_ERRORS = (OSError, httpx.HTTPError)
|
||||
|
||||
# Always attached to remote jobs and pushed datasets so LeRobot-originated work
|
||||
# is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags.
|
||||
LEROBOT_TAG = "lerobot"
|
||||
@@ -170,7 +177,7 @@ def _tail_logs(
|
||||
# the job terminal before we reconnect (avoids re-tailing the buffer).
|
||||
if done.wait(3):
|
||||
return
|
||||
except Exception:
|
||||
except _TRANSIENT_NET_ERRORS:
|
||||
if done.wait(2):
|
||||
return
|
||||
|
||||
@@ -200,7 +207,7 @@ def _poll_until_done(
|
||||
status_holder["message"] = getattr(info.status, "message", None)
|
||||
done.set()
|
||||
return stage
|
||||
except Exception:
|
||||
except _TRANSIENT_NET_ERRORS:
|
||||
failures += 1
|
||||
if failures >= max_failures:
|
||||
done.set()
|
||||
@@ -226,18 +233,24 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
user_info = api.whoami(token=token)
|
||||
username = user_info["name"]
|
||||
|
||||
now = dt.datetime.now()
|
||||
if cfg.policy is not None:
|
||||
base_name = cfg.job_name or cfg.policy.type
|
||||
repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
|
||||
cfg.policy.repo_id = repo_id
|
||||
cfg.policy.push_to_hub = True
|
||||
else:
|
||||
# Path-based policy is resolved inside validate(); fall back to a generic slug.
|
||||
repo_id = build_repo_id(username, cfg.job_name or "train", now)
|
||||
|
||||
# validate() resolves a `--policy.path=...` policy into cfg.policy and skips its
|
||||
# repo_id requirement for remote runs (we assign one below), so it's safe to run first.
|
||||
cfg.validate()
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
raise ValueError(
|
||||
"Remote training via --job.target only supports policy training, not reward models. "
|
||||
"Run reward-model training locally."
|
||||
)
|
||||
|
||||
# Auto-generate the model repo unless the user pinned one. cfg.policy is guaranteed
|
||||
# set here (validate() raises if neither policy nor reward_model is configured, and
|
||||
# reward-model runs are rejected above).
|
||||
now = dt.datetime.now(dt.UTC)
|
||||
repo_id = cfg.policy.repo_id or build_repo_id(username, cfg.job_name or cfg.policy.type, now)
|
||||
cfg.policy.repo_id = repo_id
|
||||
cfg.policy.push_to_hub = True
|
||||
|
||||
secrets: dict[str, str] = {"HF_TOKEN": token}
|
||||
if cfg.wandb.enable:
|
||||
wandb_key = resolve_wandb_api_key()
|
||||
@@ -301,15 +314,21 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
print(f" Monitor: hf jobs logs {job_id}")
|
||||
print(f" Cancel: hf jobs cancel {job_id}")
|
||||
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
signal.signal(signal.SIGINT, _detach)
|
||||
# signal.signal only works on the main thread; when called from a worker thread
|
||||
# (e.g. an orchestration framework) skip the Ctrl-C-detaches-instead-of-cancels
|
||||
# handler rather than crashing with ValueError.
|
||||
install_sigint = threading.current_thread() is threading.main_thread()
|
||||
original_sigint = signal.getsignal(signal.SIGINT) if install_sigint else None
|
||||
if install_sigint:
|
||||
signal.signal(signal.SIGINT, _detach)
|
||||
try:
|
||||
# Timeout-based join so SIGINT is delivered to the main thread promptly.
|
||||
while poll_thread.is_alive():
|
||||
poll_thread.join(timeout=0.5)
|
||||
log_thread.join(timeout=5)
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
if install_sigint:
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
if detached.is_set():
|
||||
return
|
||||
|
||||
+50
-3
@@ -18,6 +18,7 @@ import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import draccus
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -58,9 +59,9 @@ def test_poll_until_done_exits_when_done_already_set(monkeypatch):
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) is None
|
||||
|
||||
|
||||
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
|
||||
def test_poll_until_done_gives_up_after_repeated_network_failures(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
"lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(httpx.ConnectError("boom"))
|
||||
)
|
||||
done = threading.Event()
|
||||
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
@@ -68,6 +69,14 @@ def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_poll_until_done_propagates_programming_errors(monkeypatch):
|
||||
"""A bug (e.g. TypeError) must surface, not be silently retried as a transient failure."""
|
||||
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(TypeError("bug")))
|
||||
done = threading.Event()
|
||||
with pytest.raises(TypeError):
|
||||
_poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_env(monkeypatch):
|
||||
monkeypatch.setenv("WANDB_API_KEY", "abc123")
|
||||
assert resolve_wandb_api_key() == "abc123"
|
||||
@@ -121,6 +130,23 @@ def _minimal_cfg():
|
||||
)
|
||||
|
||||
|
||||
def test_validate_skips_repo_id_check_for_remote():
|
||||
"""Remote runs auto-assign repo_id in submit_to_hf, so validate() must not demand it up front."""
|
||||
cfg = _minimal_cfg() # remote target, push_to_hub default True, no explicit repo_id
|
||||
assert cfg.policy.repo_id is None
|
||||
cfg.validate() # must not raise
|
||||
|
||||
|
||||
def test_validate_requires_repo_id_for_local_push():
|
||||
"""Local runs that push to the Hub still need an explicit repo_id."""
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act"],
|
||||
)
|
||||
with pytest.raises(ValueError, match="repo_id"):
|
||||
cfg.validate()
|
||||
|
||||
|
||||
def test_build_remote_config_applies_overrides(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
@@ -192,7 +218,7 @@ def test_submit_requires_login(monkeypatch):
|
||||
|
||||
|
||||
def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
|
||||
"""A type-based policy with no explicit repo_id is auto-assigned one and submitted."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Patch get_token
|
||||
@@ -261,6 +287,27 @@ def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
assert call["labels"].get("lerobot") == "true"
|
||||
|
||||
|
||||
def test_submit_rejects_reward_model_training(monkeypatch):
|
||||
"""Remote training only supports policies; reward-model runs fail fast with a clear error."""
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
|
||||
cfg = _minimal_cfg()
|
||||
cfg.reward_model = SimpleNamespace(type="reward") # marks this as reward-model training
|
||||
monkeypatch.setattr(cfg, "validate", lambda: None) # skip pretrained-path resolution
|
||||
|
||||
with pytest.raises(ValueError, match="reward model"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_when_job_completes(monkeypatch):
|
||||
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
|
||||
|
||||
Reference in New Issue
Block a user