mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-23 11:17:02 +00:00
feat(train): add JobConfig group, save_checkpoint_to_hub flag, Hub checkpoint helper
Introduce a JobConfig draccus group on TrainPipelineConfig (--job.target/image/ timeout/detach/tags) whose is_remote property gates remote dispatch, plus a save_checkpoint_to_hub flag and validation. Add push_checkpoint_to_hub(), which uploads a saved checkpoint directory to the model repo under checkpoints/<step>/ and creates the repo idempotently (private propagates from policy.private).
This commit is contained in:
@@ -283,3 +283,28 @@ def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
||||
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
|
||||
optimizer.load_state_dict(sharded_osd)
|
||||
|
||||
|
||||
def push_checkpoint_to_hub(
|
||||
checkpoint_dir: Path,
|
||||
repo_id: str,
|
||||
*,
|
||||
private: bool | None = None,
|
||||
) -> None:
|
||||
"""Upload a saved checkpoint directory to the Hub under checkpoints/<name>/.
|
||||
|
||||
Called once per save step when save_checkpoint_to_hub is enabled, so a
|
||||
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
|
||||
The model repo is created idempotently.
|
||||
"""
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=str(checkpoint_dir),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
path_in_repo=f"checkpoints/{checkpoint_dir.name}",
|
||||
commit_message=f"checkpoint {checkpoint_dir.name}",
|
||||
)
|
||||
|
||||
@@ -123,3 +123,31 @@ class PeftConfig:
|
||||
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||
# Common values are r (alpha == rank) or 2*r.
|
||||
lora_alpha: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobConfig:
|
||||
# Where training runs. None (omitted) or "local" runs on this machine.
|
||||
# Any other value is an HF Jobs flavor and submits the run to HF Jobs.
|
||||
# List available flavors + pricing with `hf jobs hardware` command.
|
||||
target: str | None = None
|
||||
# Runtime image for the remote job (ignored for local runs).
|
||||
image: str = "huggingface/lerobot-gpu:latest"
|
||||
# Max wall-clock for the remote job as an HF Jobs duration string (e.g. "2h").
|
||||
# None (default) imposes no timeout — the job runs until the command finishes.
|
||||
timeout: str | None = None
|
||||
# Submit and exit instead of streaming the job logs in the foreground.
|
||||
detach: bool = False
|
||||
# Extra tags attached to the HF job and to any dataset this run pushes to the
|
||||
# Hub. A "lerobot" tag is always added; e.g. --job.tags '["lelab"]' adds more.
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def is_remote_target(target: str | None) -> bool:
|
||||
"""True when `target` names an HF Jobs flavor rather than a local run."""
|
||||
return target not in (None, "local")
|
||||
|
||||
@property
|
||||
def is_remote(self) -> bool:
|
||||
"""True when training should run on HF Jobs rather than this machine."""
|
||||
return self.is_remote_target(self.target)
|
||||
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EvalConfig, JobConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -113,6 +113,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# Where to run training (local default, or an HF Jobs flavor). See JobConfig.
|
||||
job: JobConfig = field(default_factory=JobConfig)
|
||||
# Push each saved checkpoint to the Hub (policy.repo_id) as it is written, not
|
||||
# just the final model (useful to monitor progress mid-run). Optional; the
|
||||
# final model is pushed regardless. Works the same locally and remotely.
|
||||
save_checkpoint_to_hub: bool = False
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
@@ -211,6 +218,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
|
||||
if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id):
|
||||
raise ValueError("save_checkpoint_to_hub requires --policy.repo_id.")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""Keys for draccus pretrained-path loading."""
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Importing concrete policy configs registers their draccus `--policy.type`
|
||||
# choices (e.g. "act") so tests can parse them.
|
||||
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.default import JobConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def test_jobconfig_defaults_are_local():
|
||||
cfg = JobConfig()
|
||||
assert cfg.target is None
|
||||
assert cfg.is_remote is False
|
||||
assert cfg.image == "huggingface/lerobot-gpu:latest"
|
||||
assert cfg.timeout is None
|
||||
assert cfg.detach is False
|
||||
|
||||
|
||||
def test_jobconfig_local_string_is_not_remote():
|
||||
assert JobConfig(target="local").is_remote is False
|
||||
|
||||
|
||||
def test_jobconfig_flavor_is_remote():
|
||||
assert JobConfig(target="a10g-small").is_remote is True
|
||||
|
||||
|
||||
def test_train_config_parses_job_target():
|
||||
parsed = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
assert parsed.job.target == "a10g-small"
|
||||
assert parsed.job.is_remote is True
|
||||
assert parsed.save_checkpoint_to_hub is False
|
||||
|
||||
|
||||
def test_save_checkpoint_to_hub_requires_repo_id():
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--policy.push_to_hub",
|
||||
"false",
|
||||
"--save_checkpoint_to_hub",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="requires --policy.repo_id"):
|
||||
cfg.validate()
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
@@ -24,6 +24,7 @@ from lerobot.common.train_utils import (
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
push_checkpoint_to_hub,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
@@ -151,3 +152,35 @@ def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
|
||||
assert loaded_step == 10
|
||||
assert loaded_optimizer is optimizer
|
||||
assert loaded_scheduler is scheduler
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
|
||||
import huggingface_hub
|
||||
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run", private=True)
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is True
|
||||
assert api.create_repo.call_args.kwargs["repo_type"] == "model"
|
||||
api.upload_folder.assert_called_once()
|
||||
kwargs = api.upload_folder.call_args.kwargs
|
||||
assert kwargs["repo_id"] == "user/run"
|
||||
assert kwargs["repo_type"] == "model"
|
||||
assert kwargs["path_in_repo"] == "checkpoints/010000"
|
||||
assert kwargs["folder_path"] == str(ckpt)
|
||||
assert kwargs["commit_message"] == "checkpoint 010000"
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):
|
||||
import huggingface_hub
|
||||
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run")
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is None
|
||||
|
||||
Reference in New Issue
Block a user