feat(train): add JobConfig group, save_checkpoint_to_hub flag, Hub checkpoint helper

Introduce a JobConfig draccus group on TrainPipelineConfig (--job.target/image/
timeout/detach/tags) whose is_remote property gates remote dispatch, plus a
save_checkpoint_to_hub flag and validation. Add push_checkpoint_to_hub(), which
uploads a saved checkpoint directory to the model repo under checkpoints/<step>/
and creates the repo idempotently (private propagates from policy.private).
This commit is contained in:
Nicolas Rabault
2026-06-22 15:43:52 +02:00
parent 73782447f2
commit 71c827f892
7 changed files with 179 additions and 2 deletions
+25
View File
@@ -283,3 +283,28 @@ def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
optimizer.load_state_dict(sharded_osd)
def push_checkpoint_to_hub(
checkpoint_dir: Path,
repo_id: str,
*,
private: bool | None = None,
) -> None:
"""Upload a saved checkpoint directory to the Hub under checkpoints/<name>/.
Called once per save step when save_checkpoint_to_hub is enabled, so a
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
The model repo is created idempotently.
"""
from huggingface_hub import HfApi
api = HfApi()
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
api.upload_folder(
folder_path=str(checkpoint_dir),
repo_id=repo_id,
repo_type="model",
path_in_repo=f"checkpoints/{checkpoint_dir.name}",
commit_message=f"checkpoint {checkpoint_dir.name}",
)
+28
View File
@@ -123,3 +123,31 @@ class PeftConfig:
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
# Common values are r (alpha == rank) or 2*r.
lora_alpha: int | None = None
@dataclass
class JobConfig:
# Where training runs. None (omitted) or "local" runs on this machine.
# Any other value is an HF Jobs flavor and submits the run to HF Jobs.
# List available flavors + pricing with `hf jobs hardware` command.
target: str | None = None
# Runtime image for the remote job (ignored for local runs).
image: str = "huggingface/lerobot-gpu:latest"
# Max wall-clock for the remote job as an HF Jobs duration string (e.g. "2h").
# None (default) imposes no timeout — the job runs until the command finishes.
timeout: str | None = None
# Submit and exit instead of streaming the job logs in the foreground.
detach: bool = False
# Extra tags attached to the HF job and to any dataset this run pushes to the
# Hub. A "lerobot" tag is always added; e.g. --job.tags '["lelab"]' adds more.
tags: list[str] = field(default_factory=list)
@staticmethod
def is_remote_target(target: str | None) -> bool:
"""True when `target` names an HF Jobs flavor rather than a local run."""
return target not in (None, "local")
@property
def is_remote(self) -> bool:
"""True when training should run on HF Jobs rather than this machine."""
return self.is_remote_target(self.target)
+11 -1
View File
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
from lerobot.utils.sample_weighting import SampleWeightingConfig
from . import parser
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
from .default import DatasetConfig, EvalConfig, JobConfig, PeftConfig, WandBConfig
from .policies import PreTrainedConfig
from .rewards import RewardModelConfig
@@ -113,6 +113,13 @@ class TrainPipelineConfig(HubMixin):
wandb: WandBConfig = field(default_factory=WandBConfig)
peft: PeftConfig | None = None
# Where to run training (local default, or an HF Jobs flavor). See JobConfig.
job: JobConfig = field(default_factory=JobConfig)
# Push each saved checkpoint to the Hub (policy.repo_id) as it is written, not
# just the final model (useful to monitor progress mid-run). Optional; the
# final model is pushed regardless. Works the same locally and remotely.
save_checkpoint_to_hub: bool = False
# Sample weighting configuration (e.g., for RA-BC training)
sample_weighting: SampleWeightingConfig | None = None
@@ -211,6 +218,9 @@ class TrainPipelineConfig(HubMixin):
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id):
raise ValueError("save_checkpoint_to_hub requires --policy.repo_id.")
@classmethod
def __get_path_fields__(cls) -> list[str]:
"""Keys for draccus pretrained-path loading."""
View File
+17
View File
@@ -0,0 +1,17 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Importing concrete policy configs registers their draccus `--policy.type`
# choices (e.g. "act") so tests can parse them.
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401
+64
View File
@@ -0,0 +1,64 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import draccus
import pytest
from lerobot.configs.default import JobConfig
from lerobot.configs.train import TrainPipelineConfig
def test_jobconfig_defaults_are_local():
cfg = JobConfig()
assert cfg.target is None
assert cfg.is_remote is False
assert cfg.image == "huggingface/lerobot-gpu:latest"
assert cfg.timeout is None
assert cfg.detach is False
def test_jobconfig_local_string_is_not_remote():
assert JobConfig(target="local").is_remote is False
def test_jobconfig_flavor_is_remote():
assert JobConfig(target="a10g-small").is_remote is True
def test_train_config_parses_job_target():
parsed = draccus.parse(
TrainPipelineConfig,
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
)
assert parsed.job.target == "a10g-small"
assert parsed.job.is_remote is True
assert parsed.save_checkpoint_to_hub is False
def test_save_checkpoint_to_hub_requires_repo_id():
cfg = draccus.parse(
TrainPipelineConfig,
args=[
"--dataset.repo_id",
"u/d",
"--policy.type",
"act",
"--policy.push_to_hub",
"false",
"--save_checkpoint_to_hub",
"true",
],
)
with pytest.raises(ValueError, match="requires --policy.repo_id"):
cfg.validate()
+34 -1
View File
@@ -15,7 +15,7 @@
# limitations under the License.
from pathlib import Path
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
@@ -24,6 +24,7 @@ from lerobot.common.train_utils import (
load_training_num_processes,
load_training_state,
load_training_step,
push_checkpoint_to_hub,
save_checkpoint,
save_training_state,
save_training_step,
@@ -151,3 +152,35 @@ def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler
def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
import huggingface_hub
ckpt = tmp_path / "010000"
(ckpt / "pretrained_model").mkdir(parents=True)
api = MagicMock()
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
push_checkpoint_to_hub(ckpt, "user/run", private=True)
api.create_repo.assert_called_once()
assert api.create_repo.call_args.kwargs["private"] is True
assert api.create_repo.call_args.kwargs["repo_type"] == "model"
api.upload_folder.assert_called_once()
kwargs = api.upload_folder.call_args.kwargs
assert kwargs["repo_id"] == "user/run"
assert kwargs["repo_type"] == "model"
assert kwargs["path_in_repo"] == "checkpoints/010000"
assert kwargs["folder_path"] == str(ckpt)
assert kwargs["commit_message"] == "checkpoint 010000"
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):
import huggingface_hub
ckpt = tmp_path / "010000"
(ckpt / "pretrained_model").mkdir(parents=True)
api = MagicMock()
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
push_checkpoint_to_hub(ckpt, "user/run")
api.create_repo.assert_called_once()
assert api.create_repo.call_args.kwargs["private"] is None