diff --git a/src/lerobot/configs/__init__.py b/src/lerobot/configs/__init__.py index be4491811..ca4062d1e 100644 --- a/src/lerobot/configs/__init__.py +++ b/src/lerobot/configs/__init__.py @@ -22,7 +22,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig`` """ from .dataset import DatasetRecordConfig -from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig +from .default import DatasetConfig, EvalConfig, JobConfig, PeftConfig, WandBConfig from .policies import PreTrainedConfig from .recipe import MessageTurn, TrainingRecipe, load_recipe from .types import ( @@ -50,6 +50,7 @@ __all__ = [ "DatasetRecordConfig", "DatasetConfig", "EvalConfig", + "JobConfig", "MessageTurn", "PeftConfig", "PreTrainedConfig", diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 621f0cdda..4fef22e1c 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -693,7 +693,7 @@ def _remote_target_in_argv() -> bool: """True when the CLI requests a remote HF Jobs run (--job.target=).""" import sys - from lerobot.configs.default import JobConfig + from lerobot.configs import JobConfig target = None args = sys.argv[1:] diff --git a/tests/jobs/test_job_config.py b/tests/jobs/test_job_config.py index d164497ad..7254e1fa1 100644 --- a/tests/jobs/test_job_config.py +++ b/tests/jobs/test_job_config.py @@ -15,7 +15,7 @@ import draccus import pytest -from lerobot.configs.default import JobConfig +from lerobot.configs import JobConfig from lerobot.configs.train import TrainPipelineConfig