Files
lerobot/tests/scripts/test_train_remote_dispatch.py
T
Nicolas Rabault 05fddeb2ba feat(train): run training remotely on HF Jobs via --job.target
When --job.target names a GPU flavor, train() dispatches to lerobot.jobs.submit_to_hf
instead of training locally: it authenticates, ensures the dataset is on the Hub
(pushing a local-only one privately), serializes a pod-compatible train_config.json
(stripping client-only fields and pointing it at the model repo), submits via HfApi.run_job
with HF_TOKEN/WANDB_API_KEY secrets, then streams logs and finishes when the model
is pushed. Wires push_checkpoint_to_hub into the training loop behind
save_checkpoint_to_hub, and tags jobs, datasets, and the model with 'lerobot' plus any --job.tags.
2026-06-22 16:24:05 +02:00

61 lines
2.1 KiB
Python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import draccus
from lerobot.configs.train import TrainPipelineConfig
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401 (registers --policy.type act)
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train
def _set_argv(monkeypatch, *args):
    """Point ``sys.argv`` at a fake ``lerobot-train`` command line for one test.

    ``args`` are the CLI tokens that follow the program name. The swap goes through
    pytest's ``monkeypatch`` fixture, which undoes the patch at test teardown.
    """
    fake_argv = ["lerobot-train"]
    fake_argv.extend(args)
    monkeypatch.setattr(sys, "argv", fake_argv)
def test_remote_target_detected_space_separated(monkeypatch):
    """``--job.target <flavor>`` given as two tokens is recognised as a remote target."""
    cli_tokens = ("--policy.type", "act", "--job.target", "a10g-small")
    _set_argv(monkeypatch, *cli_tokens)
    detected = _remote_target_in_argv()
    assert detected is True
def test_remote_target_detected_equals(monkeypatch):
    """``--job.target=<flavor>`` given as a single token is recognised as a remote target."""
    cli_tokens = ("--job.target=t4-small",)
    _set_argv(monkeypatch, *cli_tokens)
    detected = _remote_target_in_argv()
    assert detected is True
def test_local_string_is_not_remote(monkeypatch):
    """An explicit ``--job.target local`` must not be treated as a remote target."""
    cli_tokens = ("--job.target", "local")
    _set_argv(monkeypatch, *cli_tokens)
    detected = _remote_target_in_argv()
    assert detected is False
def test_no_target_is_not_remote(monkeypatch):
    """A command line without any ``--job.target`` flag is not a remote run."""
    cli_tokens = ("--policy.type", "act")
    _set_argv(monkeypatch, *cli_tokens)
    detected = _remote_target_in_argv()
    assert detected is False
def test_train_dispatches_to_submit_when_remote(monkeypatch):
    """A remote --job.target short-circuits train() to the HF Jobs submitter.

    ``lerobot.jobs.submit_to_hf`` is replaced by a fake that records the config it
    receives and returns a sentinel string, so the test never contacts the Hub.
    """
    import lerobot.jobs

    cli_args = ["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"]
    # Keep sys.argv consistent with the parsed config: _remote_target_in_argv reads
    # sys.argv, so without this the dispatch path could see pytest's own command line.
    _set_argv(monkeypatch, *cli_args)

    submitted_cfgs = []

    def fake_submit_to_hf(cfg):
        # Record the config handed to the submitter and return a recognisable sentinel.
        submitted_cfgs.append(cfg)
        return "submitted"

    monkeypatch.setattr(lerobot.jobs, "submit_to_hf", fake_submit_to_hf)

    cfg = draccus.parse(TrainPipelineConfig, args=cli_args)

    # Returns the submitter's result and never enters the local training path.
    assert train(cfg) == "submitted"
    assert submitted_cfgs == [cfg]