fix(jobs): address claude review findings on remote training

Resolve the claude[bot] review on #3856:

- Reject reward-model training under --job.target with a clear error instead
  of crashing on a None policy inside build_remote_config_file.
- Support --policy.path remote runs: validate() no longer requires repo_id for
  remote runs (it is auto-generated in submit_to_hf), and repo_id/push_to_hub
  are now set after validate() resolves the policy.
- Narrow the overly broad `except Exception` in _tail_logs/_poll_until_done to
  (OSError, httpx.HTTPError) so programming errors surface instead of being
  silently retried or counted as job failures.
- Install the SIGINT detach handler only on the main thread.
- Generate model repo timestamps in UTC.
This commit is contained in:
Nicolas Rabault
2026-06-25 16:11:06 +02:00
parent 6b64642bdb
commit ab69bc5f06
3 changed files with 92 additions and 19 deletions
+8 -1
View File
@@ -215,7 +215,14 @@ class TrainPipelineConfig(HubMixin):
self.optimizer = active_cfg.get_optimizer_preset()
self.scheduler = active_cfg.get_scheduler_preset()
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
# Remote runs auto-generate the repo_id in submit_to_hf (the policy may only be
# resolved here, from --policy.path), so don't demand it up front for them.
if (
hasattr(active_cfg, "push_to_hub")
and active_cfg.push_to_hub
and not active_cfg.repo_id
and not self.job.is_remote
):
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id):
+34 -15
View File
@@ -31,6 +31,7 @@ import threading
from pathlib import Path
from typing import TYPE_CHECKING
import httpx
from huggingface_hub import (
HfApi,
create_repo,
@@ -48,6 +49,12 @@ _SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+")
_TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"}
# huggingface_hub 1.x runs on httpx: transient HTTP/transport failures surface as
# httpx.HTTPError and socket-level errors as OSError. Catching only these keeps real
# bugs (TypeError, AttributeError, ...) from being silently retried or counted as
# job failures.
_TRANSIENT_NET_ERRORS = (OSError, httpx.HTTPError)
# Always attached to remote jobs and pushed datasets so LeRobot-originated work
# is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags.
LEROBOT_TAG = "lerobot"
@@ -170,7 +177,7 @@ def _tail_logs(
# the job terminal before we reconnect (avoids re-tailing the buffer).
if done.wait(3):
return
except Exception:
except _TRANSIENT_NET_ERRORS:
if done.wait(2):
return
@@ -200,7 +207,7 @@ def _poll_until_done(
status_holder["message"] = getattr(info.status, "message", None)
done.set()
return stage
except Exception:
except _TRANSIENT_NET_ERRORS:
failures += 1
if failures >= max_failures:
done.set()
@@ -226,18 +233,24 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
user_info = api.whoami(token=token)
username = user_info["name"]
now = dt.datetime.now()
if cfg.policy is not None:
base_name = cfg.job_name or cfg.policy.type
repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
cfg.policy.repo_id = repo_id
cfg.policy.push_to_hub = True
else:
# Path-based policy is resolved inside validate(); fall back to a generic slug.
repo_id = build_repo_id(username, cfg.job_name or "train", now)
# validate() resolves a `--policy.path=...` policy into cfg.policy and skips its
# repo_id requirement for remote runs (we assign one below), so it's safe to run first.
cfg.validate()
if cfg.is_reward_model_training:
raise ValueError(
"Remote training via --job.target only supports policy training, not reward models. "
"Run reward-model training locally."
)
# Auto-generate the model repo unless the user pinned one. cfg.policy is guaranteed
# set here (validate() raises if neither policy nor reward_model is configured, and
# reward-model runs are rejected above).
now = dt.datetime.now(dt.UTC)
repo_id = cfg.policy.repo_id or build_repo_id(username, cfg.job_name or cfg.policy.type, now)
cfg.policy.repo_id = repo_id
cfg.policy.push_to_hub = True
secrets: dict[str, str] = {"HF_TOKEN": token}
if cfg.wandb.enable:
wandb_key = resolve_wandb_api_key()
@@ -301,15 +314,21 @@ def submit_to_hf(cfg: TrainPipelineConfig) -> None:
print(f" Monitor: hf jobs logs {job_id}")
print(f" Cancel: hf jobs cancel {job_id}")
original_sigint = signal.getsignal(signal.SIGINT)
signal.signal(signal.SIGINT, _detach)
# signal.signal only works on the main thread; when called from a worker thread
# (e.g. an orchestration framework) skip the Ctrl-C-detaches-instead-of-cancels
# handler rather than crashing with ValueError.
install_sigint = threading.current_thread() is threading.main_thread()
original_sigint = signal.getsignal(signal.SIGINT) if install_sigint else None
if install_sigint:
signal.signal(signal.SIGINT, _detach)
try:
# Timeout-based join so SIGINT is delivered to the main thread promptly.
while poll_thread.is_alive():
poll_thread.join(timeout=0.5)
log_thread.join(timeout=5)
finally:
signal.signal(signal.SIGINT, original_sigint)
if install_sigint:
signal.signal(signal.SIGINT, original_sigint)
if detached.is_set():
return
+50 -3
View File
@@ -18,6 +18,7 @@ import threading
from types import SimpleNamespace
import draccus
import httpx
import pytest
from lerobot.configs.train import TrainPipelineConfig
@@ -58,9 +59,9 @@ def test_poll_until_done_exits_when_done_already_set(monkeypatch):
assert _poll_until_done("j", done, poll_interval=0.01) is None
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
def test_poll_until_done_gives_up_after_repeated_network_failures(monkeypatch):
monkeypatch.setattr(
"lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
"lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(httpx.ConnectError("boom"))
)
done = threading.Event()
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
@@ -68,6 +69,14 @@ def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
assert done.is_set()
def test_poll_until_done_propagates_programming_errors(monkeypatch):
    """Check that a non-network exception escapes ``_poll_until_done``.

    ``inspect_job`` is patched to raise ``TypeError`` (standing in for a
    programming bug). The poller must let it propagate rather than treat it
    as a transient failure to be retried.
    """

    def _raise_type_error(job_id):
        # Same signature as the patched helper; fails on every call.
        raise TypeError("bug")

    monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _raise_type_error)
    finished = threading.Event()
    with pytest.raises(TypeError):
        _poll_until_done("j", finished, poll_interval=0.001, max_failures=3)
def test_resolve_wandb_key_from_env(monkeypatch):
monkeypatch.setenv("WANDB_API_KEY", "abc123")
assert resolve_wandb_api_key() == "abc123"
@@ -121,6 +130,23 @@ def _minimal_cfg():
)
def test_validate_skips_repo_id_check_for_remote():
    """Check that ``validate()`` accepts a remote config with no ``repo_id``.

    For remote runs the repo_id is assigned later by ``submit_to_hf``, so
    validation must not require it up front.
    """
    # _minimal_cfg(): remote target, push_to_hub left at its default (True),
    # and no explicit repo_id.
    remote_cfg = _minimal_cfg()
    assert remote_cfg.policy.repo_id is None
    # Passing means this call returns without raising.
    remote_cfg.validate()
def test_validate_requires_repo_id_for_local_push():
    """Check that a local run pushing to the Hub is rejected without ``repo_id``.

    The config is parsed from CLI-style arguments that set only the dataset
    and the policy type, so no policy ``repo_id`` is supplied.
    """
    # No --job.* arguments are passed; presumably that leaves the run local.
    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act"]
    local_cfg = draccus.parse(TrainPipelineConfig, args=cli_args)
    with pytest.raises(ValueError, match="repo_id"):
        local_cfg.validate()
def test_build_remote_config_applies_overrides(tmp_path):
cfg = _minimal_cfg()
dest = tmp_path / "train_config.json"
@@ -192,7 +218,7 @@ def test_submit_requires_login(monkeypatch):
def test_submit_passes_validation_and_submits(monkeypatch):
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
"""A type-based policy with no explicit repo_id is auto-assigned one and submitted."""
from unittest.mock import MagicMock
# Patch get_token
@@ -261,6 +287,27 @@ def test_submit_passes_validation_and_submits(monkeypatch):
assert call["labels"].get("lerobot") == "true"
def test_submit_rejects_reward_model_training(monkeypatch):
    """Check that ``submit_to_hf`` rejects reward-model runs with ``ValueError``.

    Authentication and the Hub client are stubbed so the call reaches the
    reward-model check without touching the network.
    """

    class _StubHfApi:
        """Minimal stand-in for ``HfApi`` exposing only ``whoami``."""

        def __init__(self, token=None):
            pass

        def whoami(self, token=None):
            return {"name": "alice"}

    def _fake_token():
        return "tok"

    def _skip_validation():
        # Replaces cfg.validate so pretrained-path resolution is skipped.
        return None

    monkeypatch.setattr("lerobot.jobs.hf.get_token", _fake_token)
    monkeypatch.setattr("lerobot.jobs.hf.HfApi", _StubHfApi)

    reward_cfg = _minimal_cfg()
    # Setting a reward_model marks this config as reward-model training.
    reward_cfg.reward_model = SimpleNamespace(type="reward")
    monkeypatch.setattr(reward_cfg, "validate", _skip_validation)

    with pytest.raises(ValueError, match="reward model"):
        submit_to_hf(reward_cfg)
@pytest.mark.timeout(15)
def test_submit_returns_when_job_completes(monkeypatch):
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""