Mirror of https://github.com/huggingface/lerobot.git
synced 2026-06-23 19:27:08 +00:00
05fddeb2ba
When --job.target names a GPU flavor, train() dispatches to lerobot.jobs.submit_to_hf instead of training locally: it authenticates, ensures the dataset is on the Hub (pushing a local-only one privately), serializes a pod-compatible train_config.json (strips client-only fields, points at the model repo), submits via HfApi.run_job with HF_TOKEN/WANDB_API_KEY secrets, then streams logs and finishes when the model is pushed. Wires push_checkpoint_to_hub into the training loop behind save_checkpoint_to_hub, and tags jobs/datasets/model with 'lerobot' + --job.tags.
61 lines · 2.1 KiB · Python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
import sys
|
|
|
|
import draccus
|
|
|
|
from lerobot.configs.train import TrainPipelineConfig
|
|
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401 (registers --policy.type act)
|
|
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train
|
|
|
|
|
|
def _set_argv(monkeypatch, *args):
    """Make ``sys.argv`` look like a ``lerobot-train`` invocation for one test.

    Args:
        monkeypatch: the pytest ``monkeypatch`` fixture; it undoes the patch at
            test teardown.
        *args: command-line tokens placed after the program name.
    """
    fake_argv = ["lerobot-train"]
    fake_argv.extend(args)
    monkeypatch.setattr(sys, "argv", fake_argv)
|
|
|
|
|
|
def test_remote_target_detected_space_separated(monkeypatch):
    """``--job.target <flavor>`` with the value as its own argv token is detected as remote."""
    argv_tail = ("--policy.type", "act", "--job.target", "a10g-small")
    _set_argv(monkeypatch, *argv_tail)

    detected = _remote_target_in_argv()
    assert detected is True
|
|
|
|
|
|
def test_remote_target_detected_equals(monkeypatch):
    """The ``--job.target=<flavor>`` (single ``=``-joined token) form is detected as remote."""
    flag = "--job.target=t4-small"
    _set_argv(monkeypatch, flag)

    detected = _remote_target_in_argv()
    assert detected is True
|
|
|
|
|
|
def test_local_string_is_not_remote(monkeypatch):
    """An explicit ``--job.target local`` must not be reported as a remote target."""
    argv_tail = ("--job.target", "local")
    _set_argv(monkeypatch, *argv_tail)

    detected = _remote_target_in_argv()
    assert detected is False
|
|
|
|
|
|
def test_no_target_is_not_remote(monkeypatch):
    """Without any ``--job.target`` flag, ``_remote_target_in_argv()`` reports False."""
    argv_tail = ("--policy.type", "act")
    _set_argv(monkeypatch, *argv_tail)

    detected = _remote_target_in_argv()
    assert detected is False
|
|
|
|
|
|
def test_train_dispatches_to_submit_when_remote(monkeypatch):
    """A remote --job.target short-circuits train() to the HF Jobs submitter.

    ``lerobot.jobs.submit_to_hf`` is replaced with a stub that records each config
    it is handed and returns the sentinel ``"submitted"``. The test then checks that
    ``train()`` returns that sentinel and forwarded the parsed config.
    """
    import lerobot.jobs

    received_configs = []

    # Keep the parameter named ``cfg`` so the stub also accepts a keyword call.
    def fake_submit(cfg):
        received_configs.append(cfg)
        return "submitted"

    # NOTE(review): this patch only takes effect if train() looks up submit_to_hf
    # through the lerobot.jobs module at call time (e.g. a function-local import);
    # confirm if the dispatch code changes how it imports the submitter.
    monkeypatch.setattr(lerobot.jobs, "submit_to_hf", fake_submit)

    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    train_cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    # Returns the submitter's result and never enters the local training path.
    outcome = train(train_cfg)
    assert outcome == "submitted"
    assert received_configs == [train_cfg]
|