From 71c827f892a9cf5806ff3851981c804a654baf46 Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Mon, 22 Jun 2026 15:43:52 +0200 Subject: [PATCH] feat(train): add JobConfig group, save_checkpoint_to_hub flag, Hub checkpoint helper Introduce a JobConfig draccus group on TrainPipelineConfig (--job.target/image/ timeout/detach/tags) whose is_remote property gates remote dispatch, plus a save_checkpoint_to_hub flag and validation. Add push_checkpoint_to_hub(), which uploads a saved checkpoint directory to the model repo under checkpoints// and creates the repo idempotently (private propagates from policy.private). --- src/lerobot/common/train_utils.py | 25 ++++++++++++ src/lerobot/configs/default.py | 28 ++++++++++++++ src/lerobot/configs/train.py | 12 +++++- tests/jobs/__init__.py | 0 tests/jobs/conftest.py | 17 ++++++++ tests/jobs/test_job_config.py | 64 +++++++++++++++++++++++++++++++ tests/utils/test_train_utils.py | 35 ++++++++++++++++- 7 files changed, 179 insertions(+), 2 deletions(-) create mode 100644 tests/jobs/__init__.py create mode 100644 tests/jobs/conftest.py create mode 100644 tests/jobs/test_job_config.py diff --git a/src/lerobot/common/train_utils.py b/src/lerobot/common/train_utils.py index 5ae593bb8..97488774e 100644 --- a/src/lerobot/common/train_utils.py +++ b/src/lerobot/common/train_utils.py @@ -283,3 +283,28 @@ def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None: with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg): sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd) optimizer.load_state_dict(sharded_osd) + + +def push_checkpoint_to_hub( + checkpoint_dir: Path, + repo_id: str, + *, + private: bool | None = None, +) -> None: + """Upload a saved checkpoint directory to the Hub under checkpoints//. + + Called once per save step when save_checkpoint_to_hub is enabled, so a + timed-out or crashed run still leaves recoverable checkpoints on the Hub. + The model repo is created idempotently. + """ + from huggingface_hub import HfApi + + api = HfApi() + api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True) + api.upload_folder( + folder_path=str(checkpoint_dir), + repo_id=repo_id, + repo_type="model", + path_in_repo=f"checkpoints/{checkpoint_dir.name}", + commit_message=f"checkpoint {checkpoint_dir.name}", + ) diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index b809e71d9..f72d6c8d1 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -123,3 +123,31 @@ class PeftConfig: # If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters. # Common values are r (alpha == rank) or 2*r. lora_alpha: int | None = None + + +@dataclass +class JobConfig: + # Where training runs. None (omitted) or "local" runs on this machine. + # Any other value is an HF Jobs flavor and submits the run to HF Jobs. + # List available flavors + pricing with `hf jobs hardware` command. + target: str | None = None + # Runtime image for the remote job (ignored for local runs). + image: str = "huggingface/lerobot-gpu:latest" + # Max wall-clock for the remote job as an HF Jobs duration string (e.g. "2h"). + # None (default) imposes no timeout — the job runs until the command finishes. + timeout: str | None = None + # Submit and exit instead of streaming the job logs in the foreground. + detach: bool = False + # Extra tags attached to the HF job and to any dataset this run pushes to the + # Hub. A "lerobot" tag is always added; e.g. --job.tags '["lelab"]' adds more. + tags: list[str] = field(default_factory=list) + + @staticmethod + def is_remote_target(target: str | None) -> bool: + """True when `target` names an HF Jobs flavor rather than a local run.""" + return target not in (None, "local") + + @property + def is_remote(self) -> bool: + """True when training should run on HF Jobs rather than this machine.""" + return self.is_remote_target(self.target) diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index bac1a946b..17707e120 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin from lerobot.utils.sample_weighting import SampleWeightingConfig from . import parser -from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig +from .default import DatasetConfig, EvalConfig, JobConfig, PeftConfig, WandBConfig from .policies import PreTrainedConfig from .rewards import RewardModelConfig @@ -113,6 +113,13 @@ class TrainPipelineConfig(HubMixin): wandb: WandBConfig = field(default_factory=WandBConfig) peft: PeftConfig | None = None + # Where to run training (local default, or an HF Jobs flavor). See JobConfig. + job: JobConfig = field(default_factory=JobConfig) + # Push each saved checkpoint to the Hub (policy.repo_id) as it is written, not + # just the final model (useful to monitor progress mid-run). Optional; the + # final model is pushed regardless. Works the same locally and remotely. + save_checkpoint_to_hub: bool = False + # Sample weighting configuration (e.g., for RA-BC training) sample_weighting: SampleWeightingConfig | None = None @@ -211,6 +218,9 @@ class TrainPipelineConfig(HubMixin): if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id: raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.") + if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id): + raise ValueError("save_checkpoint_to_hub requires --policy.repo_id.") + @classmethod def __get_path_fields__(cls) -> list[str]: """Keys for draccus pretrained-path loading.""" diff --git a/tests/jobs/__init__.py b/tests/jobs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/jobs/conftest.py b/tests/jobs/conftest.py new file mode 100644 index 000000000..419d2f83f --- /dev/null +++ b/tests/jobs/conftest.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Importing concrete policy configs registers their draccus `--policy.type` +# choices (e.g. "act") so tests can parse them. +from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401 diff --git a/tests/jobs/test_job_config.py b/tests/jobs/test_job_config.py new file mode 100644 index 000000000..d164497ad --- /dev/null +++ b/tests/jobs/test_job_config.py @@ -0,0 +1,64 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import draccus +import pytest + +from lerobot.configs.default import JobConfig +from lerobot.configs.train import TrainPipelineConfig + + +def test_jobconfig_defaults_are_local(): + cfg = JobConfig() + assert cfg.target is None + assert cfg.is_remote is False + assert cfg.image == "huggingface/lerobot-gpu:latest" + assert cfg.timeout is None + assert cfg.detach is False + + +def test_jobconfig_local_string_is_not_remote(): + assert JobConfig(target="local").is_remote is False + + +def test_jobconfig_flavor_is_remote(): + assert JobConfig(target="a10g-small").is_remote is True + + +def test_train_config_parses_job_target(): + parsed = draccus.parse( + TrainPipelineConfig, + args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"], + ) + assert parsed.job.target == "a10g-small" + assert parsed.job.is_remote is True + assert parsed.save_checkpoint_to_hub is False + + +def test_save_checkpoint_to_hub_requires_repo_id(): + cfg = draccus.parse( + TrainPipelineConfig, + args=[ + "--dataset.repo_id", + "u/d", + "--policy.type", + "act", + "--policy.push_to_hub", + "false", + "--save_checkpoint_to_hub", + "true", + ], + ) + with pytest.raises(ValueError, match="requires --policy.repo_id"): + cfg.validate() diff --git a/tests/utils/test_train_utils.py b/tests/utils/test_train_utils.py index e3705409b..be3f19231 100644 --- a/tests/utils/test_train_utils.py +++ b/tests/utils/test_train_utils.py @@ -15,7 +15,7 @@ # limitations under the License. from pathlib import Path -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch from lerobot.common.train_utils import ( get_step_checkpoint_dir, @@ -24,6 +24,7 @@ from lerobot.common.train_utils import ( load_training_num_processes, load_training_state, load_training_step, + push_checkpoint_to_hub, save_checkpoint, save_training_state, save_training_step, @@ -151,3 +152,35 @@ def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler): assert loaded_step == 10 assert loaded_optimizer is optimizer assert loaded_scheduler is scheduler + + +def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch): + import huggingface_hub + + ckpt = tmp_path / "010000" + (ckpt / "pretrained_model").mkdir(parents=True) + api = MagicMock() + monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api) + push_checkpoint_to_hub(ckpt, "user/run", private=True) + api.create_repo.assert_called_once() + assert api.create_repo.call_args.kwargs["private"] is True + assert api.create_repo.call_args.kwargs["repo_type"] == "model" + api.upload_folder.assert_called_once() + kwargs = api.upload_folder.call_args.kwargs + assert kwargs["repo_id"] == "user/run" + assert kwargs["repo_type"] == "model" + assert kwargs["path_in_repo"] == "checkpoints/010000" + assert kwargs["folder_path"] == str(ckpt) + assert kwargs["commit_message"] == "checkpoint 010000" + + +def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch): + import huggingface_hub + + ckpt = tmp_path / "010000" + (ckpt / "pretrained_model").mkdir(parents=True) + api = MagicMock() + monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api) + push_checkpoint_to_hub(ckpt, "user/run") + api.create_repo.assert_called_once() + assert api.create_repo.call_args.kwargs["private"] is None