From ab69bc5f061ccff9ff38742692829c314d609887 Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Thu, 25 Jun 2026 16:11:06 +0200 Subject: [PATCH] fix(jobs): address claude review findings on remote training Resolve the claude[bot] review on #3856: - Reject reward-model training under --job.target with a clear error instead of crashing on a None policy inside build_remote_config_file. - Support --policy.path remote runs: validate() no longer requires repo_id for remote runs (it is auto-generated in submit_to_hf), and repo_id/push_to_hub are now set after validate() resolves the policy. - Narrow the bare `except Exception` in _tail_logs/_poll_until_done to (OSError, httpx.HTTPError) so programming errors surface instead of being silently retried or counted as job failures. - Install the SIGINT detach handler only on the main thread. - Generate model repo timestamps in UTC. --- src/lerobot/configs/train.py | 9 +++++- src/lerobot/jobs/hf.py | 49 +++++++++++++++++++++++---------- tests/jobs/test_hf.py | 53 ++++++++++++++++++++++++++++++++++-- 3 files changed, 92 insertions(+), 19 deletions(-) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 17707e120..c89d25bca 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -215,7 +215,14 @@ class TrainPipelineConfig(HubMixin): self.optimizer = active_cfg.get_optimizer_preset() self.scheduler = active_cfg.get_scheduler_preset() - if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id: + # Remote runs auto-generate the repo_id in submit_to_hf (the policy may only be + # resolved here, from --policy.path), so don't demand it up front for them. + if ( + hasattr(active_cfg, "push_to_hub") + and active_cfg.push_to_hub + and not active_cfg.repo_id + and not self.job.is_remote + ): raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.") if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id): diff --git a/src/lerobot/jobs/hf.py b/src/lerobot/jobs/hf.py index 8b1103504..d7968683a 100644 --- a/src/lerobot/jobs/hf.py +++ b/src/lerobot/jobs/hf.py @@ -31,6 +31,7 @@ import threading from pathlib import Path from typing import TYPE_CHECKING +import httpx from huggingface_hub import ( HfApi, create_repo, @@ -48,6 +49,12 @@ _SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+") _TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"} +# huggingface_hub 1.x runs on httpx: transient HTTP/transport failures surface as +# httpx.HTTPError and socket-level errors as OSError. Catching only these keeps real +# bugs (TypeError, AttributeError, ...) from being silently retried or counted as +# job failures. +_TRANSIENT_NET_ERRORS = (OSError, httpx.HTTPError) + # Always attached to remote jobs and pushed datasets so LeRobot-originated work # is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags. LEROBOT_TAG = "lerobot" @@ -170,7 +177,7 @@ def _tail_logs( # the job terminal before we reconnect (avoids re-tailing the buffer). if done.wait(3): return - except Exception: + except _TRANSIENT_NET_ERRORS: if done.wait(2): return @@ -200,7 +207,7 @@ def _poll_until_done( status_holder["message"] = getattr(info.status, "message", None) done.set() return stage - except Exception: + except _TRANSIENT_NET_ERRORS: failures += 1 if failures >= max_failures: done.set() @@ -226,18 +233,24 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None: user_info = api.whoami(token=token) username = user_info["name"] - now = dt.datetime.now() - if cfg.policy is not None: - base_name = cfg.job_name or cfg.policy.type - repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now) - cfg.policy.repo_id = repo_id - cfg.policy.push_to_hub = True - else: - # Path-based policy is resolved inside validate(); fall back to a generic slug. - repo_id = build_repo_id(username, cfg.job_name or "train", now) - + # validate() resolves a `--policy.path=...` policy into cfg.policy and skips its + # repo_id requirement for remote runs (we assign one below), so it's safe to run first. cfg.validate() + if cfg.is_reward_model_training: + raise ValueError( + "Remote training via --job.target only supports policy training, not reward models. " + "Run reward-model training locally." + ) + + # Auto-generate the model repo unless the user pinned one. cfg.policy is guaranteed + # set here (validate() raises if neither policy nor reward_model is configured, and + # reward-model runs are rejected above). + now = dt.datetime.now(dt.UTC) + repo_id = cfg.policy.repo_id or build_repo_id(username, cfg.job_name or cfg.policy.type, now) + cfg.policy.repo_id = repo_id + cfg.policy.push_to_hub = True + secrets: dict[str, str] = {"HF_TOKEN": token} if cfg.wandb.enable: wandb_key = resolve_wandb_api_key() @@ -301,15 +314,21 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None: print(f" Monitor: hf jobs logs {job_id}") print(f" Cancel: hf jobs cancel {job_id}") - original_sigint = signal.getsignal(signal.SIGINT) - signal.signal(signal.SIGINT, _detach) + # signal.signal only works on the main thread; when called from a worker thread + # (e.g. an orchestration framework) skip the Ctrl-C-detaches-instead-of-cancels + # handler rather than crashing with ValueError. + install_sigint = threading.current_thread() is threading.main_thread() + original_sigint = signal.getsignal(signal.SIGINT) if install_sigint else None + if install_sigint: + signal.signal(signal.SIGINT, _detach) try: # Timeout-based join so SIGINT is delivered to the main thread promptly. while poll_thread.is_alive(): poll_thread.join(timeout=0.5) log_thread.join(timeout=5) finally: - signal.signal(signal.SIGINT, original_sigint) + if install_sigint: + signal.signal(signal.SIGINT, original_sigint) if detached.is_set(): return diff --git a/tests/jobs/test_hf.py b/tests/jobs/test_hf.py index f6e17c278..798f83264 100644 --- a/tests/jobs/test_hf.py +++ b/tests/jobs/test_hf.py @@ -18,6 +18,7 @@ import threading from types import SimpleNamespace import draccus +import httpx import pytest from lerobot.configs.train import TrainPipelineConfig @@ -58,9 +59,9 @@ def test_poll_until_done_exits_when_done_already_set(monkeypatch): assert _poll_until_done("j", done, poll_interval=0.01) is None -def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch): +def test_poll_until_done_gives_up_after_repeated_network_failures(monkeypatch): monkeypatch.setattr( - "lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom")) + "lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(httpx.ConnectError("boom")) ) done = threading.Event() result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3) @@ -68,6 +69,14 @@ def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch): assert done.is_set() +def test_poll_until_done_propagates_programming_errors(monkeypatch): + """A bug (e.g. TypeError) must surface, not be silently retried as a transient failure.""" + monkeypatch.setattr("lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(TypeError("bug"))) + done = threading.Event() + with pytest.raises(TypeError): + _poll_until_done("j", done, poll_interval=0.001, max_failures=3) + + def test_resolve_wandb_key_from_env(monkeypatch): monkeypatch.setenv("WANDB_API_KEY", "abc123") assert resolve_wandb_api_key() == "abc123" @@ -121,6 +130,23 @@ def _minimal_cfg(): ) +def test_validate_skips_repo_id_check_for_remote(): + """Remote runs auto-assign repo_id in submit_to_hf, so validate() must not demand it up front.""" + cfg = _minimal_cfg() # remote target, push_to_hub default True, no explicit repo_id + assert cfg.policy.repo_id is None + cfg.validate() # must not raise + + +def test_validate_requires_repo_id_for_local_push(): + """Local runs that push to the Hub still need an explicit repo_id.""" + cfg = draccus.parse( + TrainPipelineConfig, + args=["--dataset.repo_id", "u/d", "--policy.type", "act"], + ) + with pytest.raises(ValueError, match="repo_id"): + cfg.validate() + + def test_build_remote_config_applies_overrides(tmp_path): cfg = _minimal_cfg() dest = tmp_path / "train_config.json" @@ -192,7 +218,7 @@ def test_submit_requires_login(monkeypatch): def test_submit_passes_validation_and_submits(monkeypatch): - """Regression: repo_id must be set BEFORE cfg.validate() or validation raises.""" + """A type-based policy with no explicit repo_id is auto-assigned one and submitted.""" from unittest.mock import MagicMock # Patch get_token @@ -261,6 +287,27 @@ def test_submit_passes_validation_and_submits(monkeypatch): assert call["labels"].get("lerobot") == "true" +def test_submit_rejects_reward_model_training(monkeypatch): + """Remote training only supports policies; reward-model runs fail fast with a clear error.""" + monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok") + + class FakeHfApi: + def __init__(self, token=None): + pass + + def whoami(self, token=None): + return {"name": "alice"} + + monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi) + + cfg = _minimal_cfg() + cfg.reward_model = SimpleNamespace(type="reward") # marks this as reward-model training + monkeypatch.setattr(cfg, "validate", lambda: None) # skip pretrained-path resolution + + with pytest.raises(ValueError, match="reward model"): + submit_to_hf(cfg) + + @pytest.mark.timeout(15) def test_submit_returns_when_job_completes(monkeypatch): """Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""