diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index f98d062a7..52c6783fe 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -51,6 +51,7 @@ from lerobot.configs.train import TrainPipelineConfig from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state from lerobot.datasets.factory import make_train_eval_datasets from lerobot.envs import close_envs, make_env, make_env_pre_post_processors +from lerobot.jobs import submit_to_hf from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.rewards import make_reward_pre_post_processors @@ -190,8 +191,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): accelerator: Optional Accelerator instance. If None, one will be created automatically. """ if cfg.job.is_remote: - from lerobot.jobs import submit_to_hf - return submit_to_hf(cfg) from lerobot.utils.import_utils import require_package diff --git a/tests/scripts/test_train_remote_dispatch.py b/tests/scripts/test_train_remote_dispatch.py index 841bc798a..50431da9e 100644 --- a/tests/scripts/test_train_remote_dispatch.py +++ b/tests/scripts/test_train_remote_dispatch.py @@ -54,10 +54,10 @@ def test_no_target_is_not_remote(monkeypatch): def test_train_dispatches_to_submit_when_remote(monkeypatch): """A remote --job.target short-circuits train() to the HF Jobs submitter.""" - import lerobot.jobs + import lerobot.scripts.lerobot_train as train_module captured = [] - monkeypatch.setattr(lerobot.jobs, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted") + monkeypatch.setattr(train_module, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted") cfg = draccus.parse( TrainPipelineConfig, args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],