From f2b1e62dda9971818e9e60a9227bfcb952b2248e Mon Sep 17 00:00:00 2001 From: Nicolas Rabault Date: Mon, 29 Jun 2026 09:11:54 +0200 Subject: [PATCH] refactor(train): hoist submit_to_hf import to module top The `from lerobot.jobs import submit_to_hf` was a function-local import in train(); it pulls no heavy/optional deps and has no circular-import risk, so move it to the top-level import block. --- src/lerobot/scripts/lerobot_train.py | 3 +-- tests/scripts/test_train_remote_dispatch.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index f98d062a7..52c6783fe 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -51,6 +51,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state from lerobot.datasets.factory import make_train_eval_datasets from lerobot.envs import close_envs, make_env, make_env_pre_post_processors +from lerobot.jobs import submit_to_hf from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.rewards import make_reward_pre_post_processors @@ -190,8 +191,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): accelerator: Optional Accelerator instance. If None, one will be created automatically. """ if cfg.job.is_remote: - from lerobot.jobs import submit_to_hf - return submit_to_hf(cfg) from lerobot.utils.import_utils import require_package diff --git a/tests/scripts/test_train_remote_dispatch.py b/tests/scripts/test_train_remote_dispatch.py index 841bc798a..50431da9e 100644 --- a/tests/scripts/test_train_remote_dispatch.py +++ b/tests/scripts/test_train_remote_dispatch.py @@ -54,10 +54,10 @@ def test_no_target_is_not_remote(monkeypatch): def test_train_dispatches_to_submit_when_remote(monkeypatch): """A remote --job.target short-circuits train() to the HF Jobs submitter.""" - import lerobot.jobs + import lerobot.scripts.lerobot_train as train_module captured = [] - monkeypatch.setattr(lerobot.jobs, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted") + monkeypatch.setattr(train_module, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted") cfg = draccus.parse( TrainPipelineConfig, args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],