mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-25 20:27:05 +00:00
Compare commits
24 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 204228985c | |||
| 527f7a45c2 | |||
| 651c113cd3 | |||
| 838ab9e234 | |||
| 955b172585 | |||
| 6256e69c29 | |||
| d09842b734 | |||
| 6e9d699710 | |||
| ab69bc5f06 | |||
| 6b64642bdb | |||
| 4efa9da0d9 | |||
| 71a89d30f0 | |||
| 8a3a411af6 | |||
| 5cf72ec9d4 | |||
| 79fd82443b | |||
| 2d9e286f18 | |||
| 30cc3d59f5 | |||
| 6ad1e6b6ae | |||
| 79f2eafcc6 | |||
| 60cbe71857 | |||
| ed8694c67f | |||
| 3bbdad8442 | |||
| 05fddeb2ba | |||
| 71c827f892 |
+1
-1
@@ -138,7 +138,7 @@ lerobot-replay --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.
|
||||
--dataset.repo_id=${HF_USER}/my_task --dataset.episode=0
|
||||
```
|
||||
|
||||
**4.9 Train** (default: ACT — fastest, lowest memory). Apple silicon: `--policy.device=mps`. See §6/§7 for policy and duration.
|
||||
**4.9 Train** (default: ACT — fastest, lowest memory). Apple silicon: `--policy.device=mps`. No local GPU? Add `--job.target=<flavor>` (e.g. `a10g-small`, list them with `hf jobs hardware`) to run on Hugging Face Jobs instead. See §6/§7 for policy and duration.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
|
||||
@@ -120,6 +120,14 @@ lerobot-train \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
No local GPU? Add `--job.target=<flavor>` (e.g. `a10g-small`) to either command and `lerobot-train` runs it on [Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) instead — it uploads a local-only dataset for you and pushes the trained model. List flavors with `hf jobs hardware`.
|
||||
|
||||
To resume, point `--config_path` at a checkpoint and add `--resume=true`. It accepts a local path or a Hub repo id (the latest checkpoint is fetched), and works locally or on a job by adding `--job.target=<flavor>`:
|
||||
|
||||
```bash
|
||||
lerobot-train --config_path=${HF_USER}/policy_test --resume=true --job.target=a10g-small
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
|
||||
|
||||
@@ -96,3 +96,4 @@ Notes:
|
||||
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||
- Prefer not to write the `hf jobs run` wrapper yourself? `lerobot-train` can submit the job for you: just add `--job.target=<flavor>` to a normal training command and it handles dataset upload, log streaming, and the final model push. See the [imitation-learning training guide](./il_robots).
|
||||
|
||||
@@ -506,6 +506,12 @@ lerobot-train \
|
||||
--resume=true
|
||||
```
|
||||
|
||||
`--config_path` also accepts a **Hub repo id**: if a run pushed its checkpoints to the Hub (with `--save_checkpoint_to_hub=true`), you can resume straight from the repo — its latest checkpoint is downloaded and training continues, restoring the optimizer, scheduler, step counter and data order:
|
||||
|
||||
```bash
|
||||
lerobot-train --config_path=${HF_USER}/my_policy --resume=true
|
||||
```
|
||||
|
||||
If you do not want to push your model to the hub after training use `--policy.push_to_hub=false`.
|
||||
|
||||
Additionally you can provide extra `tags` or specify a `license` for your model or make the model repo `private` by adding this: `--policy.private=true --policy.tags=\[ppo,rl\] --policy.license=mit`
|
||||
@@ -518,7 +524,9 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
> **Tip:** if you just want to launch a standard training run, you can skip building the command below and use the integrated **Train on HF Jobs via `--job.target`** flow described further down — `lerobot-train` then submits the job, uploads a local-only dataset for you, and streams the logs.
|
||||
|
||||
To run the training manually use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
@@ -591,6 +599,51 @@ Once the training is started you can go to [Jobs](https://huggingface.co/setting
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Train on HF Jobs via `--job.target` (integrated CLI)
|
||||
|
||||
`lerobot-train` runs locally by default. To run on a HuggingFace GPU without constructing the Docker command yourself, pass `--job.target` with a hardware flavor name:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_test \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_policy \
|
||||
--job.target=a10g-small
|
||||
```
|
||||
|
||||
List available flavors and pricing with `hf jobs hardware`. The run streams its logs to your terminal; press Ctrl-C to detach (the job keeps running in the cloud). Re-attach or cancel with:
|
||||
|
||||
```bash
|
||||
hf jobs logs <job-id>
|
||||
hf jobs cancel <job-id>
|
||||
```
|
||||
|
||||
If your dataset exists only locally (not yet on the Hub), it is automatically pushed to a **private** Hub repo so the job can download it by `repo_id` (nothing is made public). The trained model is pushed to the model repo at the end of the run. To also push every intermediate checkpoint to the Hub as it is saved (so you can monitor progress mid-run), add `--save_checkpoint_to_hub=true` — this requires a runtime image that includes this feature.
|
||||
|
||||
Every job (and any dataset pushed by the run) is tagged `lerobot` so it's easy to find on the Hub. Add your own with `--job.tags '["my-tag"]'`.
|
||||
|
||||
By default the job is capped at `2d` (48h) of wall-clock. Override it with an HF Jobs duration string, e.g. `--job.timeout=4h` to fail faster or `--job.timeout=7d` for a longer run.
|
||||
|
||||
> **Note:** the model repo is created up front (it holds the staged training config the job runs from). If a run fails before the model is pushed, that repo is left on the Hub so you can inspect it — it is not deleted automatically, so repeated failures can leave empty repos behind. Remove one with `hf repo delete <repo-id>`.
|
||||
|
||||
**Prerequisites:** run `hf auth login` before submitting. For Weights & Biases integration, run `wandb login` or set `WANDB_API_KEY` on your machine — the key is forwarded to the job automatically.
|
||||
|
||||
**Resuming on a job.** Adding `--job.target` to a resume command runs the resume in the cloud — the same command works locally or remotely. The checkpoint repo is the source of truth, and new checkpoints continue the lineage in the same repo:
|
||||
|
||||
```bash
|
||||
# resume a Hub run on a job (its checkpoints are already on the Hub)
|
||||
lerobot-train --config_path=${HF_USER}/my_policy --resume=true --job.target=a10g-small
|
||||
|
||||
# resume a LOCAL run on a job — the checkpoint is uploaded to a private Hub repo first,
|
||||
# then the job resumes from it (a local-only dataset is uploaded the same way)
|
||||
lerobot-train \
|
||||
--config_path=outputs/train/act_so101_test/checkpoints/last/pretrained_model/train_config.json \
|
||||
--resume=true \
|
||||
--job.target=a10g-small
|
||||
```
|
||||
|
||||
Job settings come from the current command, so override `--job.target`, `--job.timeout`, etc. as needed; for the resumed run to itself be resumable later, keep `--save_checkpoint_to_hub=true`.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
@@ -612,6 +665,8 @@ hf upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
Use `lerobot-rollout` to deploy a trained policy on your robot. You can choose different strategies depending on your needs:
|
||||
|
||||
The examples below load the model from `--policy.path`. To pin a specific pushed version — useful once `--save_checkpoint_to_hub=true` has committed several checkpoints — add `--policy.pretrained_revision` with a commit hash, branch, or tag. Each pushed checkpoint is tagged with its step (e.g. `--policy.pretrained_revision=010000`), so you can recover a checkpoint by step without looking up its commit sha.
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Base mode (no recording)">
|
||||
```bash
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfApi, snapshot_download
|
||||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
@@ -35,6 +36,7 @@ from lerobot.utils.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.hub import find_latest_hub_checkpoint
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
@@ -283,3 +285,61 @@ def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
||||
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
|
||||
optimizer.load_state_dict(sharded_osd)
|
||||
|
||||
|
||||
def push_checkpoint_to_hub(
|
||||
checkpoint_dir: Path,
|
||||
repo_id: str,
|
||||
*,
|
||||
private: bool | None = None,
|
||||
) -> None:
|
||||
"""Upload a saved checkpoint directory to the Hub under checkpoints/<name>/.
|
||||
|
||||
Called once per save step when save_checkpoint_to_hub is enabled, so a
|
||||
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
|
||||
The model repo is created idempotently, and the commit is tagged with the
|
||||
checkpoint step so a checkpoint can be recovered with
|
||||
--policy.pretrained_revision=<step> instead of a commit sha.
|
||||
"""
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
|
||||
commit = api.upload_folder(
|
||||
folder_path=str(checkpoint_dir),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
path_in_repo=f"checkpoints/{checkpoint_dir.name}",
|
||||
commit_message=f"checkpoint {checkpoint_dir.name}",
|
||||
)
|
||||
api.create_tag(
|
||||
repo_id=repo_id,
|
||||
tag=checkpoint_dir.name,
|
||||
revision=commit.oid,
|
||||
repo_type="model",
|
||||
exist_ok=True,
|
||||
)
|
||||
|
||||
|
||||
def resolve_resume_checkpoint(repo_id: str, output_dir: Path) -> Path:
|
||||
"""Download the latest checkpoint of a Hub training repo into a local run dir.
|
||||
|
||||
The symmetric counterpart to `push_checkpoint_to_hub`: given a model repo holding
|
||||
`checkpoints/<step>/{pretrained_model,training_state}` subtrees, download the highest-numbered step
|
||||
into `output_dir/checkpoints/<step>/`, recreate the local `last` symlink, and return that local
|
||||
checkpoint dir. Used to resume training from the Hub on a machine (or HF Jobs pod) that does not
|
||||
have the original local run dir.
|
||||
"""
|
||||
latest = find_latest_hub_checkpoint(repo_id)
|
||||
if latest is None:
|
||||
raise FileNotFoundError(
|
||||
f"No checkpoint found in '{repo_id}' under '{CHECKPOINTS_DIR}/'. "
|
||||
"Was the run trained with --save_checkpoint_to_hub?"
|
||||
)
|
||||
snapshot_download(
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
allow_patterns=f"{latest}/*",
|
||||
local_dir=str(output_dir),
|
||||
)
|
||||
checkpoint_dir = output_dir / latest
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
return checkpoint_dir
|
||||
|
||||
@@ -22,7 +22,7 @@ Import them directly: ``from lerobot.configs.train import TrainPipelineConfig``
|
||||
"""
|
||||
|
||||
from .dataset import DatasetRecordConfig
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EvalConfig, JobConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .recipe import MessageTurn, TrainingRecipe, load_recipe
|
||||
from .types import (
|
||||
@@ -50,6 +50,7 @@ __all__ = [
|
||||
"DatasetRecordConfig",
|
||||
"DatasetConfig",
|
||||
"EvalConfig",
|
||||
"JobConfig",
|
||||
"MessageTurn",
|
||||
"PeftConfig",
|
||||
"PreTrainedConfig",
|
||||
|
||||
@@ -123,3 +123,35 @@ class PeftConfig:
|
||||
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||
# Common values are r (alpha == rank) or 2*r.
|
||||
lora_alpha: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobConfig:
|
||||
# Where training runs. None (omitted) or "local" runs on this machine.
|
||||
# Any other value is an HF Jobs flavor and submits the run to HF Jobs.
|
||||
# List available flavors + pricing with `hf jobs hardware` command.
|
||||
target: str | None = None
|
||||
# Runtime image for the remote job (ignored for local runs).
|
||||
image: str = "huggingface/lerobot-gpu:latest"
|
||||
# Max wall-clock for the remote job as an HF Jobs duration string (e.g. "2h").
|
||||
# Defaults to "2d": We pass an explicit, generous cap instead. Set a smaller
|
||||
# value to fail fast, or a larger one for long runs.
|
||||
timeout: str | None = "2d"
|
||||
# Submit and exit instead of streaming the job logs in the foreground.
|
||||
detach: bool = False
|
||||
# Extra tags attached to the HF job and to any dataset this run pushes to the
|
||||
# Hub. A "lerobot" tag is always added; e.g. --job.tags '["lelab"]' adds more.
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
# Two entry points to the same predicate: the staticmethod tests a raw target string
|
||||
# straight from argv (before any JobConfig exists, to decide dispatch early), while the
|
||||
# property is the ergonomic accessor for code that already holds a config instance.
|
||||
@staticmethod
|
||||
def is_remote_target(target: str | None) -> bool:
|
||||
"""True when `target` names an HF Jobs flavor rather than a local run."""
|
||||
return target not in (None, "local")
|
||||
|
||||
@property
|
||||
def is_remote(self) -> bool:
|
||||
"""True when training should run on HF Jobs rather than this machine."""
|
||||
return self.is_remote_target(self.target)
|
||||
|
||||
+100
-43
@@ -26,11 +26,12 @@ from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot import envs
|
||||
from lerobot.optim import LRSchedulerConfig, OptimizerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.constants import PRETRAINED_MODEL_DIR
|
||||
from lerobot.utils.hub import HubMixin, find_latest_hub_checkpoint
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EvalConfig, JobConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -83,10 +84,11 @@ class TrainPipelineConfig(HubMixin):
|
||||
# with the same value for `dir` its contents will be overwritten unless you set `resume` to true.
|
||||
output_dir: Path | None = None
|
||||
job_name: str | None = None
|
||||
# Set `resume` to true to resume a previous run. In order for this to work, you will need to make sure
|
||||
# `dir` is the directory of an existing run with at least one checkpoint in it.
|
||||
# Note that when resuming a run, the default behavior is to use the configuration from the checkpoint,
|
||||
# regardless of what's provided with the training command at the time of resumption.
|
||||
# Set `resume` to true to resume a previous run. Pass `--config_path` pointing at either a local
|
||||
# checkpoint's train_config.json or a Hub repo id holding `checkpoints/<step>/` subtrees (the
|
||||
# latest checkpoint is downloaded and resumed from). Note that when resuming, the default behavior
|
||||
# is to use the configuration from the checkpoint, regardless of what's provided with the training
|
||||
# command at the time of resumption (CLI `--*` flags still override).
|
||||
resume: bool = False
|
||||
# `seed` is used for training (eg: model initialization, dataset shuffling)
|
||||
# AND for the evaluation environments.
|
||||
@@ -113,6 +115,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# Where to run training (local default, or an HF Jobs flavor). See JobConfig.
|
||||
job: JobConfig = field(default_factory=JobConfig)
|
||||
# Push each saved checkpoint to the Hub (policy.repo_id) as it is written, not
|
||||
# just the final model (useful to monitor progress mid-run). Optional; the
|
||||
# final model is pushed regardless. Works the same locally and remotely.
|
||||
save_checkpoint_to_hub: bool = False
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
@@ -132,10 +141,17 @@ class TrainPipelineConfig(HubMixin):
|
||||
return self.reward_model # type: ignore[return-value]
|
||||
return self.policy # type: ignore[return-value]
|
||||
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
def _resolve_pretrained_from_cli(self) -> None:
|
||||
"""Resolve the pretrained source passed on the CLI into a loaded config.
|
||||
|
||||
The pretrained paths (`--policy.path`, `--reward_model.path`) and
|
||||
`--config_path` are only recoverable by re-reading the CLI args: draccus
|
||||
has already consumed them by the time `validate()` runs, so they are not
|
||||
reflected on `self`. Exactly one source applies, in priority order:
|
||||
reward-model path, policy path, then resume.
|
||||
"""
|
||||
reward_model_path = parser.get_path_arg("reward_model")
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
|
||||
if reward_model_path:
|
||||
cli_overrides = parser.get_cli_overrides("reward_model")
|
||||
@@ -144,31 +160,54 @@ class TrainPipelineConfig(HubMixin):
|
||||
)
|
||||
self.reward_model.pretrained_path = str(Path(reward_model_path))
|
||||
elif policy_path:
|
||||
yaml_overrides = parser.get_yaml_overrides("policy")
|
||||
cli_overrides = parser.get_cli_overrides("policy") or []
|
||||
self.policy = PreTrainedConfig.from_pretrained(
|
||||
policy_path, cli_overrides=yaml_overrides + cli_overrides
|
||||
)
|
||||
overrides = parser.get_yaml_overrides("policy") + (parser.get_cli_overrides("policy") or [])
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=overrides)
|
||||
self.policy.pretrained_path = Path(policy_path)
|
||||
elif self.resume:
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
self._resolve_resume_checkpoint()
|
||||
|
||||
if not Path(config_path).resolve().exists():
|
||||
raise NotADirectoryError(
|
||||
f"{config_path=} is expected to be a local path. "
|
||||
"Resuming from the hub is not supported for now."
|
||||
)
|
||||
def _resolve_resume_checkpoint(self) -> None:
|
||||
"""Point the trainable config at the checkpoint named by `--config_path`.
|
||||
|
||||
`config_path` is either a local path (to a checkpoint's train_config.json or its
|
||||
pretrained_model/ dir) or a Hub repo id. For a Hub repo, the latest checkpoint is downloaded
|
||||
into a fresh local run dir and resumed from there. The download is skipped when dispatching to
|
||||
an HF Job (`job.is_remote`): the pod performs it when it runs the resume locally, and
|
||||
`submit_to_hf` resolves the source repo for the remote command.
|
||||
"""
|
||||
config_path = parser.parse_arg("config_path")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
f"A config_path is expected when resuming a run. Please specify path to {TRAIN_CONFIG_NAME}"
|
||||
)
|
||||
|
||||
if Path(config_path).resolve().exists():
|
||||
policy_dir = Path(config_path).parent
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
self.checkpoint_path = policy_dir.parent
|
||||
elif self.job.is_remote:
|
||||
return
|
||||
else:
|
||||
from lerobot.common.train_utils import resolve_resume_checkpoint
|
||||
|
||||
# `self.output_dir` was loaded from the checkpoint's config and points at the original
|
||||
# run's (now-absent) local dir. Resume into a fresh local dir instead, unless the user
|
||||
# passed --output_dir explicitly.
|
||||
cli_output_dir = parser.parse_arg("output_dir")
|
||||
if cli_output_dir:
|
||||
self.output_dir = Path(cli_output_dir)
|
||||
else:
|
||||
now = dt.datetime.now()
|
||||
self.output_dir = Path("outputs/train") / f"{now:%Y-%m-%d}/{now:%H-%M-%S}_resume"
|
||||
self.checkpoint_path = resolve_resume_checkpoint(config_path, self.output_dir)
|
||||
policy_dir = self.checkpoint_path / PRETRAINED_MODEL_DIR
|
||||
|
||||
if self.policy is not None:
|
||||
self.policy.pretrained_path = policy_dir
|
||||
if self.reward_model is not None:
|
||||
self.reward_model.pretrained_path = str(policy_dir)
|
||||
|
||||
def validate(self) -> None:
|
||||
self._resolve_pretrained_from_cli()
|
||||
|
||||
if self.policy is None and self.reward_model is None:
|
||||
raise ValueError(
|
||||
@@ -208,9 +247,19 @@ class TrainPipelineConfig(HubMixin):
|
||||
self.optimizer = active_cfg.get_optimizer_preset()
|
||||
self.scheduler = active_cfg.get_scheduler_preset()
|
||||
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
# Remote runs auto-generate the repo_id in submit_to_hf (the policy may only be
|
||||
# resolved here, from --policy.path), so don't demand it up front for them.
|
||||
if (
|
||||
hasattr(active_cfg, "push_to_hub")
|
||||
and active_cfg.push_to_hub
|
||||
and not active_cfg.repo_id
|
||||
and not self.job.is_remote
|
||||
):
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
|
||||
if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id):
|
||||
raise ValueError("save_checkpoint_to_hub requires --policy.repo_id.")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""Keys for draccus pretrained-path loading."""
|
||||
@@ -247,22 +296,30 @@ class TrainPipelineConfig(HubMixin):
|
||||
elif Path(model_id).is_file():
|
||||
config_file = model_id
|
||||
else:
|
||||
dl_kwargs = {
|
||||
"repo_id": model_id,
|
||||
"revision": revision,
|
||||
"cache_dir": cache_dir,
|
||||
"force_download": force_download,
|
||||
"proxies": proxies,
|
||||
"resume_download": resume_download,
|
||||
"token": token,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=TRAIN_CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
config_file = hf_hub_download(filename=TRAIN_CONFIG_NAME, **dl_kwargs)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
# No root train_config.json: this is a repo of periodic checkpoints from an
|
||||
# interrupted run. Fall back to the latest checkpoint's config so the run can be
|
||||
# resumed straight from the repo with `--config_path=<repo>`.
|
||||
latest = find_latest_hub_checkpoint(model_id, token=token, revision=revision)
|
||||
if latest is None:
|
||||
raise FileNotFoundError(
|
||||
f"{TRAIN_CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
config_file = hf_hub_download(
|
||||
filename=f"{latest}/{PRETRAINED_MODEL_DIR}/{TRAIN_CONFIG_NAME}", **dl_kwargs
|
||||
)
|
||||
|
||||
cli_args = kwargs.pop("cli_args", [])
|
||||
# Legacy RA-BC migration only applies to framework-saved checkpoints (always JSON).
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .hf import submit_to_hf
|
||||
|
||||
__all__ = ["submit_to_hf"]
|
||||
@@ -0,0 +1,57 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Make a training dataset reachable from an HF Job pod.
|
||||
|
||||
The pod can't see the host's ~/.cache/huggingface/lerobot, so the dataset has to
|
||||
live on the Hub: the pod downloads it by repo_id at train time (the forwarded
|
||||
HF_TOKEN covers private datasets). A dataset already on the Hub is used as-is; a
|
||||
local-only dataset is pushed to a PRIVATE repo first (never public).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
|
||||
def ensure_dataset_available(repo_id: str, *, api: HfApi, tags: list[str] | None = None) -> None:
|
||||
"""Ensure repo_id resolves on the Hub, pushing a local-only dataset privately first.
|
||||
|
||||
`tags` are attached to the dataset only when we push it (an already-on-Hub
|
||||
dataset is left untouched). Raises RuntimeError if the dataset is neither on
|
||||
the Hub nor in the local cache.
|
||||
"""
|
||||
if api.repo_exists(repo_id, repo_type="dataset"):
|
||||
return
|
||||
|
||||
local_present = (HF_LEROBOT_HOME / repo_id / "meta" / "info.json").is_file()
|
||||
if not local_present:
|
||||
raise RuntimeError(
|
||||
f"Dataset '{repo_id}' is not in the local cache ({HF_LEROBOT_HOME}) and could not be "
|
||||
f"reached on the Hub — it may not exist, or be private and inaccessible with your "
|
||||
f"token. Record or download it first, or run `hf auth login`."
|
||||
)
|
||||
|
||||
print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...")
|
||||
# Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed.
|
||||
require_package("datasets", extra="dataset")
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags)
|
||||
print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.")
|
||||
@@ -0,0 +1,423 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Run a lerobot training on HF Jobs (HuggingFace GPUs).
|
||||
|
||||
Ported and simplified from lelab's runners/hf_cloud.py: no UI log queue, no
|
||||
registry — just submit and stream to stdout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime as dt
|
||||
import json
|
||||
import netrc
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
from huggingface_hub import (
|
||||
HfApi,
|
||||
create_repo,
|
||||
fetch_job_logs,
|
||||
get_token,
|
||||
inspect_job,
|
||||
run_job,
|
||||
upload_file,
|
||||
)
|
||||
|
||||
from lerobot.common.train_utils import push_checkpoint_to_hub
|
||||
from lerobot.configs import parser
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
_SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||
|
||||
_TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"}
|
||||
|
||||
# huggingface_hub 1.x runs on httpx: transient HTTP/transport failures surface as
|
||||
# httpx.HTTPError and socket-level errors as OSError. Catching only these keeps real
|
||||
# bugs (TypeError, AttributeError, ...) from being silently retried or counted as
|
||||
# job failures.
|
||||
_TRANSIENT_NET_ERRORS = (OSError, httpx.HTTPError)
|
||||
|
||||
# Always attached to remote jobs and pushed datasets so LeRobot-originated work
|
||||
# is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags.
|
||||
LEROBOT_TAG = "lerobot"
|
||||
|
||||
|
||||
def resolve_job_tags(extra: list[str] | None) -> list[str]:
|
||||
"""Return the tag list for a run: the lerobot tag plus any extras, deduped, order-stable."""
|
||||
tags = [LEROBOT_TAG, *(extra or [])]
|
||||
seen: set[str] = set()
|
||||
return [t for t in tags if not (t in seen or seen.add(t))]
|
||||
|
||||
|
||||
def resolve_wandb_api_key() -> str | None:
|
||||
"""Host's wandb key for forwarding to the job: $WANDB_API_KEY, else ~/.netrc."""
|
||||
key = os.environ.get("WANDB_API_KEY")
|
||||
if key:
|
||||
return key
|
||||
try:
|
||||
rc = netrc.netrc()
|
||||
except (FileNotFoundError, netrc.NetrcParseError, OSError):
|
||||
return None
|
||||
auth = rc.authenticators("api.wandb.ai")
|
||||
if auth is None:
|
||||
return None
|
||||
_login, _account, password = auth
|
||||
return password or None
|
||||
|
||||
|
||||
def build_repo_id(username: str, job_name: str, now: dt.datetime) -> str:
|
||||
"""Generate the model repo id for a remote run: <user>/<job_name>_<timestamp>."""
|
||||
slug = _SLUG_RE.sub("-", job_name).strip("-") or "train"
|
||||
stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return f"{username}/{slug}_{stamp}"
|
||||
|
||||
|
||||
def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | None = None) -> Path:
|
||||
"""Write a train_config.json for the pod, with remote overrides applied.
|
||||
|
||||
The pod runs `lerobot-train --config_path=<dest>` and downloads the dataset
|
||||
by repo_id into its own cache. Client-only fields are stripped so the config
|
||||
is accepted by the trainer image: `job` (pure client orchestration) is always
|
||||
removed, and `save_checkpoint_to_hub` is removed unless explicitly enabled —
|
||||
older lerobot images reject unknown keys, so the default keeps the config
|
||||
compatible with the released `lerobot-gpu` image. `tags` are merged into
|
||||
policy.tags so the trained model the pod pushes carries them too.
|
||||
"""
|
||||
remote = copy.deepcopy(cfg)
|
||||
remote.policy.push_to_hub = True
|
||||
remote.policy.repo_id = repo_id
|
||||
# Don't pin the client's resolved device (e.g. "mps"); let the pod auto-detect its GPU.
|
||||
remote.policy.device = None
|
||||
# Drop any host-local dataset root; the pod resolves the dataset by repo_id.
|
||||
remote.dataset.root = None
|
||||
if tags:
|
||||
existing = list(remote.policy.tags or [])
|
||||
remote.policy.tags = existing + [t for t in tags if t not in existing]
|
||||
|
||||
# Encode to the canonical, pod-parseable dict, then drop the keys the released
|
||||
# trainer image doesn't know about.
|
||||
data = remote.to_dict()
|
||||
data.pop("job", None)
|
||||
if not remote.save_checkpoint_to_hub:
|
||||
data.pop("save_checkpoint_to_hub", None)
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_text(json.dumps(data, indent=4))
|
||||
return dest
|
||||
|
||||
|
||||
def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str:
|
||||
"""Upload train_config.json to the model repo and return the repo_id for --config_path."""
|
||||
create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token)
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
config_path = build_remote_config_file(cfg, repo_id, Path(tmp) / "train_config.json", tags=tags)
|
||||
upload_file(
|
||||
path_or_fileobj=config_path,
|
||||
path_in_repo="train_config.json",
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
token=token,
|
||||
)
|
||||
return repo_id
|
||||
|
||||
|
||||
def _tail_logs(
|
||||
job_id: str,
|
||||
done: threading.Event,
|
||||
success_marker: str | None = None,
|
||||
success_event: threading.Event | None = None,
|
||||
) -> None:
|
||||
"""Stream job logs to stdout, reconnecting on dropped streams until done is set.
|
||||
|
||||
Each reconnect re-fetches the full buffered log, so we track how many lines
|
||||
were already printed and skip them — otherwise a fast-failing job's traceback
|
||||
gets reprinted on every reconnect.
|
||||
|
||||
When `success_marker` appears in a line, set `success_event` and `done` so the
|
||||
caller can finish as soon as the trained model lands on the Hub, rather than
|
||||
waiting out the platform's post-run finalization (which can add ~30s).
|
||||
"""
|
||||
printed = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
seen = 0
|
||||
for line in fetch_job_logs(job_id=job_id, follow=True):
|
||||
seen += 1
|
||||
if seen <= printed:
|
||||
continue # already shown on a previous connection
|
||||
printed = seen
|
||||
# fetch_job_logs yields SSE data without trailing newlines, so add one
|
||||
# per entry — otherwise all log lines concatenate onto a single line.
|
||||
print(line.rstrip("\n"), flush=True)
|
||||
if success_marker and success_event is not None and success_marker in line:
|
||||
success_event.set()
|
||||
done.set()
|
||||
return
|
||||
if done.is_set():
|
||||
return
|
||||
# Stream closed cleanly. Wait a moment so the status poller can mark
|
||||
# the job terminal before we reconnect (avoids re-tailing the buffer).
|
||||
if done.wait(3):
|
||||
return
|
||||
except _TRANSIENT_NET_ERRORS:
|
||||
if done.wait(2):
|
||||
return
|
||||
|
||||
|
||||
def _poll_until_done(
|
||||
job_id: str,
|
||||
done: threading.Event,
|
||||
poll_interval: float = 5.0,
|
||||
status_holder: dict | None = None,
|
||||
max_failures: int = 6,
|
||||
) -> str | None:
|
||||
"""Poll inspect_job until a terminal stage or until `done` is set.
|
||||
|
||||
Returns the terminal stage string, or None if `done` was set first (detach)
|
||||
or after `max_failures` consecutive inspect_job errors. When a terminal stage
|
||||
is reached and `status_holder` is given, records `status_holder["message"]`
|
||||
(the platform's status message, e.g. "Job timeout").
|
||||
"""
|
||||
failures = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
info = inspect_job(job_id=job_id)
|
||||
failures = 0
|
||||
stage = info.status.stage.value
|
||||
if stage in _TERMINAL_STAGES:
|
||||
if status_holder is not None:
|
||||
status_holder["message"] = getattr(info.status, "message", None)
|
||||
done.set()
|
||||
return stage
|
||||
except _TRANSIENT_NET_ERRORS:
|
||||
failures += 1
|
||||
if failures >= max_failures:
|
||||
done.set()
|
||||
return None
|
||||
done.wait(poll_interval)
|
||||
return None
|
||||
|
||||
|
||||
def _pod_forwarded_args(
|
||||
argv: list[str], drop_names: tuple[str, ...] = (), drop_prefixes: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
"""User CLI overrides to replay on the pod, minus flags the submitter sets itself.
|
||||
|
||||
Handles both `--name=value` and `--name value` forms. Forwarding the user's overrides (e.g.
|
||||
`--steps`, `--save_checkpoint_to_hub`) makes a remote resume behave like the same local command.
|
||||
"""
|
||||
out: list[str] = []
|
||||
skip_next = False
|
||||
for i, tok in enumerate(argv):
|
||||
if skip_next:
|
||||
skip_next = False
|
||||
continue
|
||||
name = tok.split("=", 1)[0]
|
||||
if name in drop_names or any(name.startswith(p) for p in drop_prefixes):
|
||||
if "=" not in tok and i + 1 < len(argv) and not argv[i + 1].startswith("--"):
|
||||
skip_next = True # also drop the space-separated value
|
||||
continue
|
||||
out.append(tok)
|
||||
return out
|
||||
|
||||
|
||||
def _build_resume_job(cfg: TrainPipelineConfig, username: str) -> tuple[str, list[str]]:
|
||||
"""Resolve the model repo and pod command to resume a run on a job.
|
||||
|
||||
A Hub `config_path` is resumed from directly: its checkpoint config already targets that repo,
|
||||
so new checkpoints continue the lineage there. A local `config_path` has its checkpoint uploaded
|
||||
to a new PRIVATE repo first, and the resumed run is forced to push back to it. The pod command
|
||||
always carries `--job.target=local` so the checkpoint's saved `job.target` can't make the pod
|
||||
re-dispatch itself.
|
||||
"""
|
||||
config_path = parser.parse_arg("config_path")
|
||||
forwarded = _pod_forwarded_args(
|
||||
sys.argv[1:],
|
||||
drop_names=("--config_path", "--policy.repo_id", "--policy.push_to_hub"),
|
||||
drop_prefixes=("--job.",),
|
||||
)
|
||||
|
||||
if Path(config_path).exists():
|
||||
# Local checkpoint: stage it on the Hub so the pod can resume from it, and push back there.
|
||||
# Resolve so a `last` symlink uploads under its real step name (digit), which the pod's
|
||||
# latest-checkpoint lookup keys on.
|
||||
checkpoint_dir = Path(cfg.checkpoint_path).resolve()
|
||||
source_repo = build_repo_id(username, cfg.job_name or "train", dt.datetime.now(dt.UTC))
|
||||
push_checkpoint_to_hub(checkpoint_dir, source_repo, private=True)
|
||||
extra = [f"--policy.repo_id={source_repo}", "--policy.push_to_hub=true"]
|
||||
else:
|
||||
source_repo = config_path
|
||||
extra = []
|
||||
|
||||
command = [
|
||||
"lerobot-train",
|
||||
*forwarded,
|
||||
f"--config_path={source_repo}",
|
||||
"--job.target=local",
|
||||
*extra,
|
||||
]
|
||||
return source_repo, command
|
||||
|
||||
|
||||
def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
"""Submit a training job to HF Jobs infrastructure.
|
||||
|
||||
Validates cfg, resolves credentials, ensures the dataset is on the Hub, then either stages a
|
||||
sanitized config (fresh run) or resumes from a checkpoint repo, submits the job, and tails logs
|
||||
until completion or detaches immediately. Ctrl-C detaches without cancelling the remote job.
|
||||
"""
|
||||
token = get_token()
|
||||
if not token:
|
||||
raise RuntimeError("Not logged in to Hugging Face. Run `hf auth login` first.")
|
||||
|
||||
api = HfApi(token=token)
|
||||
user_info = api.whoami(token=token)
|
||||
username = user_info["name"]
|
||||
|
||||
now = dt.datetime.now(dt.UTC)
|
||||
fresh_repo_id: str | None = None
|
||||
if not cfg.resume:
|
||||
# Resolve the model repo and mark it for push BEFORE validate(): validate() requires repo_id
|
||||
# to be set whenever push_to_hub is True. (A resume reuses the checkpoint's repo instead.)
|
||||
if cfg.policy is not None:
|
||||
base_name = cfg.job_name or cfg.policy.type
|
||||
fresh_repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
|
||||
cfg.policy.repo_id = fresh_repo_id
|
||||
cfg.policy.push_to_hub = True
|
||||
else:
|
||||
# Path-based policy is resolved inside validate(); fall back to a generic slug.
|
||||
fresh_repo_id = build_repo_id(username, cfg.job_name or "train", now)
|
||||
|
||||
cfg.validate()
|
||||
|
||||
if cfg.is_reward_model_training:
|
||||
raise ValueError(
|
||||
"Remote training via --job.target only supports policy training, not reward models. "
|
||||
"Run reward-model training locally."
|
||||
)
|
||||
|
||||
secrets: dict[str, str] = {"HF_TOKEN": token}
|
||||
if cfg.wandb.enable:
|
||||
wandb_key = resolve_wandb_api_key()
|
||||
if wandb_key is None:
|
||||
raise ValueError(
|
||||
"wandb is enabled but no WANDB_API_KEY found. "
|
||||
"Set it via `export WANDB_API_KEY=...` or add it to ~/.netrc."
|
||||
)
|
||||
secrets["WANDB_API_KEY"] = wandb_key
|
||||
|
||||
tags = resolve_job_tags(cfg.job.tags)
|
||||
# The dataset must be reachable from the pod for both fresh and resumed runs; a local-only
|
||||
# dataset is pushed PRIVATE here. Hoisted before the resume/fresh branch since it applies to both.
|
||||
ensure_dataset_available(cfg.dataset.repo_id, api=api, tags=tags)
|
||||
|
||||
if cfg.resume:
|
||||
repo_id, command = _build_resume_job(cfg, username)
|
||||
else:
|
||||
config_repo_id = _stage_config_on_hub(cfg, fresh_repo_id, token, tags=tags)
|
||||
repo_id = fresh_repo_id
|
||||
command = ["lerobot-train", f"--config_path={config_repo_id}"]
|
||||
|
||||
print(f"Submitting job to HF Jobs (flavor={cfg.job.target}, image={cfg.job.image}) ...")
|
||||
job_info = run_job(
|
||||
image=cfg.job.image,
|
||||
command=command,
|
||||
flavor=cfg.job.target,
|
||||
secrets=secrets,
|
||||
timeout=cfg.job.timeout,
|
||||
# HF Jobs labels are key/value; expose each tag as a queryable label.
|
||||
labels=dict.fromkeys(tags, "true"),
|
||||
)
|
||||
job_id = job_info.id
|
||||
job_url = getattr(job_info, "url", None)
|
||||
print(f"Job submitted: {job_id}")
|
||||
if job_url:
|
||||
print(f" Job page: {job_url}")
|
||||
print(f" Model repo: https://huggingface.co/{repo_id}")
|
||||
print(f" Monitor: hf jobs logs {job_id}")
|
||||
print(f" Cancel: hf jobs cancel {job_id}")
|
||||
|
||||
if cfg.job.detach:
|
||||
return
|
||||
|
||||
done = threading.Event()
|
||||
detached = threading.Event()
|
||||
pushed_ok = threading.Event()
|
||||
stage_holder: dict[str, str | None] = {}
|
||||
|
||||
def _poll() -> None:
|
||||
stage_holder["stage"] = _poll_until_done(job_id, done, status_holder=stage_holder)
|
||||
|
||||
poll_thread = threading.Thread(target=_poll, daemon=True)
|
||||
poll_thread.start()
|
||||
# Finish as soon as the model is pushed, rather than waiting out the platform's
|
||||
# post-run finalization before the job stage flips to COMPLETED. This matches the
|
||||
# exact log line emitted by PreTrainedPolicy.push_model_to_hub — the two must stay
|
||||
# in sync. If it ever stops matching we just fall back to stage-based completion
|
||||
# (~30s slower), so the contract is an optimization, not a correctness requirement.
|
||||
success_marker = f"Model pushed to https://huggingface.co/{repo_id}"
|
||||
log_thread = threading.Thread(
|
||||
target=_tail_logs, args=(job_id, done, success_marker, pushed_ok), daemon=True
|
||||
)
|
||||
log_thread.start()
|
||||
|
||||
def _detach(sig, frame):
|
||||
detached.set()
|
||||
done.set()
|
||||
print("\nDetached. Job is still running.")
|
||||
print(f" Monitor: hf jobs logs {job_id}")
|
||||
print(f" Cancel: hf jobs cancel {job_id}")
|
||||
|
||||
# signal.signal only works on the main thread; when called from a worker thread
|
||||
# (e.g. an orchestration framework) skip the Ctrl-C-detaches-instead-of-cancels
|
||||
# handler rather than crashing with ValueError.
|
||||
install_sigint = threading.current_thread() is threading.main_thread()
|
||||
original_sigint = signal.getsignal(signal.SIGINT) if install_sigint else None
|
||||
if install_sigint:
|
||||
signal.signal(signal.SIGINT, _detach)
|
||||
try:
|
||||
# Timeout-based join so SIGINT is delivered to the main thread promptly.
|
||||
while poll_thread.is_alive():
|
||||
poll_thread.join(timeout=0.5)
|
||||
log_thread.join(timeout=5)
|
||||
finally:
|
||||
if install_sigint:
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
if detached.is_set():
|
||||
return
|
||||
|
||||
if pushed_ok.is_set():
|
||||
print(f"\nTraining complete — model pushed to https://huggingface.co/{repo_id}")
|
||||
return
|
||||
|
||||
stage = stage_holder.get("stage")
|
||||
if stage != "COMPLETED":
|
||||
message = stage_holder.get("message")
|
||||
detail = f" ({message})" if message else ""
|
||||
raise RuntimeError(
|
||||
f"Job {job_id} ended with stage={stage}{detail}. Check logs: hf jobs logs {job_id}"
|
||||
)
|
||||
@@ -340,6 +340,9 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
ignore_patterns=["*.tmp", "*.log"],
|
||||
)
|
||||
|
||||
# Contract: lerobot.jobs.hf.submit_to_hf watches for this exact
|
||||
# "Model pushed to <url>" line to end a remote run early. Keep the wording
|
||||
# and URL format in sync (it falls back to status polling if they drift).
|
||||
logging.info(f"Model pushed to {commit_info.repo_url.url}")
|
||||
|
||||
def generate_model_card(
|
||||
|
||||
@@ -41,6 +41,7 @@ from lerobot.common.train_utils import (
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
push_checkpoint_to_hub,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
@@ -187,6 +188,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||
"""
|
||||
if cfg.job.is_remote:
|
||||
from lerobot.jobs import submit_to_hf
|
||||
|
||||
return submit_to_hf(cfg)
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("accelerate", extra="training")
|
||||
@@ -597,6 +603,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
optim_state_dict=optim_state_dict,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if cfg.save_checkpoint_to_hub:
|
||||
push_checkpoint_to_hub(
|
||||
checkpoint_dir,
|
||||
cfg.policy.repo_id,
|
||||
private=cfg.policy.private,
|
||||
)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -677,8 +689,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def _remote_target_in_argv() -> bool:
|
||||
"""True when the CLI requests a remote HF Jobs run (--job.target=<non-local>)."""
|
||||
import sys
|
||||
|
||||
from lerobot.configs import JobConfig
|
||||
|
||||
target = None
|
||||
args = sys.argv[1:]
|
||||
for i, tok in enumerate(args):
|
||||
if tok == "--job.target" and i + 1 < len(args):
|
||||
target = args[i + 1]
|
||||
elif tok.startswith("--job.target="):
|
||||
target = tok.split("=", 1)[1]
|
||||
return JobConfig.is_remote_target(target)
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_plugins()
|
||||
if _remote_target_in_argv():
|
||||
# The policy device is resolved on the remote pod, not here, so silence the
|
||||
# client-side "Device '...' is not available" warning PreTrainedConfig emits
|
||||
# while parsing the config (it fires before train() can dispatch remotely).
|
||||
logging.getLogger("lerobot.configs.policies").setLevel(logging.ERROR)
|
||||
train()
|
||||
|
||||
|
||||
|
||||
@@ -20,9 +20,33 @@ from typing import Any, TypeVar
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
|
||||
from lerobot.utils.constants import CHECKPOINTS_DIR
|
||||
|
||||
T = TypeVar("T", bound="HubMixin")
|
||||
|
||||
|
||||
def find_latest_hub_checkpoint(
|
||||
repo_id: str,
|
||||
*,
|
||||
token: str | bool | None = None,
|
||||
revision: str | None = None,
|
||||
) -> str | None:
|
||||
"""Repo-relative path of the most recent checkpoint in a training repo.
|
||||
|
||||
Training runs push checkpoints to ``checkpoints/<step>/`` (see
|
||||
``push_checkpoint_to_hub``). This lists those step dirs and returns
|
||||
``checkpoints/<highest-step>``, or ``None`` if the repo has no checkpoints.
|
||||
"""
|
||||
files = HfApi().list_repo_files(repo_id=repo_id, repo_type="model", revision=revision, token=token)
|
||||
prefix = f"{CHECKPOINTS_DIR}/"
|
||||
steps = {
|
||||
name for f in files if f.startswith(prefix) and (name := f[len(prefix) :].split("/", 1)[0]).isdigit()
|
||||
}
|
||||
if not steps:
|
||||
return None
|
||||
return f"{CHECKPOINTS_DIR}/{max(steps, key=int)}"
|
||||
|
||||
|
||||
class HubMixin:
|
||||
"""
|
||||
A Mixin containing the functionality to push an object to the hub.
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
import lerobot.configs.train as tc
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
class _FakeHTTPError(tc.HfHubHTTPError):
|
||||
"""HfHubHTTPError that can be raised without a real HTTP response object."""
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
def test_from_pretrained_falls_back_to_latest_checkpoint_config(tmp_path, monkeypatch):
|
||||
"""A Hub repo with no root train_config.json (an interrupted run that only pushed
|
||||
checkpoints/) resolves via the latest checkpoint's config."""
|
||||
# A real train_config.json written by save_pretrained, to be returned by the fallback.
|
||||
parsed = tc.draccus.parse(TrainPipelineConfig, args=["--dataset.repo_id", "u/d"])
|
||||
cfg_file = tmp_path / "train_config.json"
|
||||
parsed._save_pretrained(tmp_path)
|
||||
assert cfg_file.is_file()
|
||||
|
||||
calls = []
|
||||
|
||||
def fake_hf_hub_download(filename=None, **kwargs):
|
||||
calls.append(filename)
|
||||
if filename == "train_config.json":
|
||||
raise _FakeHTTPError() # no root config
|
||||
if filename == "checkpoints/000010/pretrained_model/train_config.json":
|
||||
return str(cfg_file)
|
||||
raise AssertionError(f"unexpected filename {filename}")
|
||||
|
||||
monkeypatch.setattr(tc, "hf_hub_download", fake_hf_hub_download)
|
||||
monkeypatch.setattr(
|
||||
tc, "find_latest_hub_checkpoint", lambda repo_id, token=None, revision=None: "checkpoints/000010"
|
||||
)
|
||||
|
||||
loaded = TrainPipelineConfig.from_pretrained("user/interrupted-run")
|
||||
assert loaded.dataset.repo_id == "u/d"
|
||||
# Tried the root config first, then fell back to the latest checkpoint's config.
|
||||
assert calls == ["train_config.json", "checkpoints/000010/pretrained_model/train_config.json"]
|
||||
|
||||
|
||||
def test_from_pretrained_raises_when_no_root_config_and_no_checkpoints(monkeypatch):
|
||||
"""No root config AND no checkpoints → a clear FileNotFoundError, not the raw HTTP error."""
|
||||
|
||||
def fake_hf_hub_download(filename=None, **kwargs):
|
||||
raise _FakeHTTPError()
|
||||
|
||||
monkeypatch.setattr(tc, "hf_hub_download", fake_hf_hub_download)
|
||||
monkeypatch.setattr(tc, "find_latest_hub_checkpoint", lambda repo_id, token=None, revision=None: None)
|
||||
|
||||
with pytest.raises(FileNotFoundError, match="train_config.json not found"):
|
||||
TrainPipelineConfig.from_pretrained("user/empty-repo")
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Importing concrete policy configs registers their draccus `--policy.type`
|
||||
# choices (e.g. "act") so tests can parse them.
|
||||
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401
|
||||
@@ -0,0 +1,69 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
|
||||
def _api_with_dataset(exists: bool):
|
||||
api = MagicMock()
|
||||
api.repo_exists.return_value = exists
|
||||
return api
|
||||
|
||||
|
||||
def _make_local_cache(tmp_path, repo_id: str) -> None:
|
||||
"""Create the minimal local-cache layout that ensure_dataset_available checks."""
|
||||
info = tmp_path / repo_id / "meta" / "info.json"
|
||||
info.parent.mkdir(parents=True)
|
||||
info.write_text("{}")
|
||||
|
||||
|
||||
# Branch 1: dataset already on Hub → no push, no error (pod downloads by repo_id).
|
||||
def test_dataset_already_on_hub_is_noop():
|
||||
api = _api_with_dataset(True)
|
||||
assert ensure_dataset_available("user/ds", api=api) is None
|
||||
api.repo_exists.assert_called_once_with("user/ds", repo_type="dataset")
|
||||
|
||||
|
||||
# Branch 2: not on Hub but present locally → always push privately.
|
||||
def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.HF_LEROBOT_HOME", tmp_path)
|
||||
_make_local_cache(tmp_path, "user/ds")
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
mock_ds_cls = MagicMock()
|
||||
fake_datasets_module = MagicMock()
|
||||
fake_datasets_module.LeRobotDataset = mock_ds_cls
|
||||
monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)
|
||||
# The `datasets` extra isn't installed in the base test env; skip the import guard.
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.require_package", lambda *a, **k: None)
|
||||
|
||||
assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None
|
||||
|
||||
mock_ds_cls.assert_called_once_with("user/ds")
|
||||
mock_ds_cls.return_value.push_to_hub.assert_called_once_with(private=True, tags=["lerobot", "lelab"])
|
||||
|
||||
|
||||
# Branch 3: not on Hub, NOT in local cache → RuntimeError.
|
||||
def test_dataset_neither_on_hub_nor_local_raises(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.HF_LEROBOT_HOME", tmp_path)
|
||||
# tmp_path is empty — no local cache.
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
with pytest.raises(RuntimeError, match="not in the local cache"):
|
||||
ensure_dataset_available("user/ds", api=api)
|
||||
@@ -0,0 +1,464 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import json
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import draccus
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.jobs.hf import (
|
||||
_poll_until_done,
|
||||
build_remote_config_file,
|
||||
build_repo_id,
|
||||
resolve_job_tags,
|
||||
resolve_wandb_api_key,
|
||||
submit_to_hf,
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_job_tags_always_includes_lerobot_and_dedups():
|
||||
assert resolve_job_tags(None) == ["lerobot"]
|
||||
assert resolve_job_tags([]) == ["lerobot"]
|
||||
assert resolve_job_tags(["lelab"]) == ["lerobot", "lelab"]
|
||||
# lerobot isn't duplicated if passed explicitly; order is stable.
|
||||
assert resolve_job_tags(["lelab", "lerobot", "lelab"]) == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def _fake_inspect(stage_value):
|
||||
return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=SimpleNamespace(value=stage_value)))
|
||||
|
||||
|
||||
def test_poll_until_done_returns_terminal_stage(monkeypatch):
|
||||
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("COMPLETED"))
|
||||
done = threading.Event()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED"
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
|
||||
# Non-terminal forever; with done pre-set the loop must not block and returns None.
|
||||
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", _fake_inspect("RUNNING"))
|
||||
done = threading.Event()
|
||||
done.set()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) is None
|
||||
|
||||
|
||||
def test_poll_until_done_gives_up_after_repeated_network_failures(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(httpx.ConnectError("boom"))
|
||||
)
|
||||
done = threading.Event()
|
||||
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
assert result is None
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_poll_until_done_propagates_programming_errors(monkeypatch):
|
||||
"""A bug (e.g. TypeError) must surface, not be silently retried as a transient failure."""
|
||||
monkeypatch.setattr("lerobot.jobs.hf.inspect_job", lambda job_id: (_ for _ in ()).throw(TypeError("bug")))
|
||||
done = threading.Event()
|
||||
with pytest.raises(TypeError):
|
||||
_poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_env(monkeypatch):
|
||||
monkeypatch.setenv("WANDB_API_KEY", "abc123")
|
||||
assert resolve_wandb_api_key() == "abc123"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_missing(monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
monkeypatch.setenv("HOME", str(tmp_path)) # no ~/.netrc here
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: (_ for _ in ()).throw(FileNotFoundError()))
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_netrc(monkeypatch):
|
||||
# No env var → fall back to the wandb credentials in ~/.netrc.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
assert host == "api.wandb.ai"
|
||||
return ("login", "account", "netrc-secret")
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() == "netrc-secret"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_netrc_without_wandb_entry(monkeypatch):
|
||||
# ~/.netrc exists but has no api.wandb.ai entry → None.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_build_repo_id_sanitizes_and_timestamps():
|
||||
now = dt.datetime(2026, 6, 19, 10, 22, 3)
|
||||
assert build_repo_id("alice", "act", now) == "alice/act_2026-06-19_10-22-03"
|
||||
# Runs of illegal characters collapse to a single dash; edges are trimmed.
|
||||
assert build_repo_id("alice", "my cool/run!!", now) == "alice/my-cool-run_2026-06-19_10-22-03"
|
||||
# A name with nothing usable falls back to "train".
|
||||
assert build_repo_id("alice", "///", now) == "alice/train_2026-06-19_10-22-03"
|
||||
|
||||
|
||||
def _minimal_cfg():
|
||||
return draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
|
||||
|
||||
def test_validate_skips_repo_id_check_for_remote():
|
||||
"""Remote runs auto-assign repo_id in submit_to_hf, so validate() must not demand it up front."""
|
||||
cfg = _minimal_cfg() # remote target, push_to_hub default True, no explicit repo_id
|
||||
assert cfg.policy.repo_id is None
|
||||
cfg.validate() # must not raise
|
||||
|
||||
|
||||
def test_validate_requires_repo_id_for_local_push():
|
||||
"""Local runs that push to the Hub still need an explicit repo_id."""
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act"],
|
||||
)
|
||||
with pytest.raises(ValueError, match="repo_id"):
|
||||
cfg.validate()
|
||||
|
||||
|
||||
def test_build_remote_config_applies_overrides(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
out = build_remote_config_file(cfg, "u/run", dest)
|
||||
assert out == dest
|
||||
data = json.loads(dest.read_text())
|
||||
# `job` is client-only orchestration and must be stripped for the pod.
|
||||
assert "job" not in data
|
||||
# save_checkpoint_to_hub defaults off → omitted so older images accept the config.
|
||||
assert "save_checkpoint_to_hub" not in data
|
||||
assert data["policy"]["push_to_hub"] is True
|
||||
assert data["policy"]["repo_id"] == "u/run"
|
||||
assert data["policy"]["device"] is None # pod auto-detects its GPU
|
||||
assert data["dataset"]["root"] is None # pod resolves the dataset by repo_id
|
||||
# the caller's cfg must be left untouched (function works on a deep copy)
|
||||
assert cfg.job.target == "a10g-small"
|
||||
assert cfg.save_checkpoint_to_hub is False
|
||||
|
||||
|
||||
def test_build_remote_config_includes_checkpoint_flag_when_enabled(tmp_path):
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--save_checkpoint_to_hub",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest)
|
||||
data = json.loads(dest.read_text())
|
||||
# explicitly enabled → kept in the config (requires a matching trainer image).
|
||||
assert data["save_checkpoint_to_hub"] is True
|
||||
assert "job" not in data
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_into_policy(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# tags propagate to the model the pod pushes.
|
||||
assert data["policy"]["tags"] == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_without_duplicating(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
cfg.policy.tags = ["existing", "lerobot"]
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# pre-existing policy tags are kept; only genuinely-new tags are appended (no dup "lerobot").
|
||||
assert data["policy"]["tags"] == ["existing", "lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_submit_requires_login(monkeypatch):
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: None)
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="hf auth login"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
"""A type-based policy with no explicit repo_id is auto-assigned one and submitted."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Patch get_token
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
# Patch HfApi so whoami returns alice
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
|
||||
# ensure_dataset_available returns None; patch it out so no Hub access happens
|
||||
# (hf.py imports it at module level, so patch it on lerobot.jobs.hf).
|
||||
monkeypatch.setattr("lerobot.jobs.hf.ensure_dataset_available", lambda *a, **kw: None)
|
||||
|
||||
# Patch _stage_config_on_hub to skip network
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub",
|
||||
lambda cfg, repo_id, token, tags=None: repo_id,
|
||||
)
|
||||
|
||||
# Patch run_job to return a fake job
|
||||
fake_job = MagicMock()
|
||||
fake_job.id = "job-123"
|
||||
run_job_calls = []
|
||||
|
||||
def fake_run_job(**kwargs):
|
||||
run_job_calls.append(kwargs)
|
||||
return fake_job
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", fake_run_job)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--job.detach",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
|
||||
# Must NOT raise (pre-fix this raised ValueError about missing repo_id)
|
||||
submit_to_hf(cfg)
|
||||
|
||||
assert len(run_job_calls) == 1, "run_job should have been called exactly once"
|
||||
assert cfg.policy.repo_id is not None
|
||||
assert cfg.policy.repo_id.startswith("alice/")
|
||||
call = run_job_calls[0]
|
||||
# The pod runs `lerobot-train --config_path=<staged repo>` on the requested flavor/image.
|
||||
assert call["command"][0] == "lerobot-train"
|
||||
assert call["command"][1].startswith("--config_path=")
|
||||
assert call["flavor"] == "a10g-small"
|
||||
assert call["image"] == "huggingface/lerobot-gpu:latest"
|
||||
# The Hub token is forwarded so the pod can pull the (possibly private) dataset.
|
||||
assert call["secrets"]["HF_TOKEN"] == "tok"
|
||||
# Every job carries the lerobot tag as a queryable label.
|
||||
assert call["labels"].get("lerobot") == "true"
|
||||
|
||||
|
||||
def test_submit_rejects_reward_model_training(monkeypatch):
|
||||
"""Remote training only supports policies; reward-model runs fail fast with a clear error."""
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
|
||||
cfg = _minimal_cfg()
|
||||
cfg.reward_model = SimpleNamespace(type="reward") # marks this as reward-model training
|
||||
monkeypatch.setattr(cfg, "validate", lambda: None) # skip pretrained-path resolution
|
||||
|
||||
with pytest.raises(ValueError, match="reward model"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_when_job_completes(monkeypatch):
|
||||
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job is already COMPLETED on the first poll.
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None)
|
||||
),
|
||||
)
|
||||
# Log stream ends immediately.
|
||||
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Runs in the pytest main thread (signal handler install requires it); the
|
||||
# @timeout marker fails the test instead of hanging if it regresses.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_on_model_pushed_marker(monkeypatch):
|
||||
"""Finish when the model-pushed log appears, even if the job stage never flips."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job stays RUNNING forever — only the log marker can end the command.
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None)
|
||||
),
|
||||
)
|
||||
pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
|
||||
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--policy.repo_id",
|
||||
"alice/myrun",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
],
|
||||
)
|
||||
# Must return via the model-pushed marker despite the perpetual RUNNING stage.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
|
||||
"""wandb.enable with no key reachable anywhere fails fast, before submitting."""
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--wandb.enable",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="WANDB_API_KEY"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_raises_when_job_ends_in_error(monkeypatch):
|
||||
"""A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job fails: a terminal ERROR stage carrying the platform's status message.
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout")
|
||||
),
|
||||
)
|
||||
# Logs end without the model-pushed marker.
|
||||
monkeypatch.setattr("lerobot.jobs.hf.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=r"stage=ERROR \(Job timeout\)"):
|
||||
submit_to_hf(cfg)
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.configs import JobConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def test_jobconfig_defaults_are_local():
|
||||
cfg = JobConfig()
|
||||
assert cfg.target is None
|
||||
assert cfg.is_remote is False
|
||||
assert cfg.image == "huggingface/lerobot-gpu:latest"
|
||||
assert cfg.timeout == "2d"
|
||||
assert cfg.detach is False
|
||||
|
||||
|
||||
def test_jobconfig_local_string_is_not_remote():
|
||||
assert JobConfig(target="local").is_remote is False
|
||||
|
||||
|
||||
def test_jobconfig_flavor_is_remote():
|
||||
assert JobConfig(target="a10g-small").is_remote is True
|
||||
|
||||
|
||||
def test_train_config_parses_job_target():
|
||||
parsed = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
assert parsed.job.target == "a10g-small"
|
||||
assert parsed.job.is_remote is True
|
||||
assert parsed.save_checkpoint_to_hub is False
|
||||
|
||||
|
||||
def test_save_checkpoint_to_hub_requires_repo_id():
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--policy.push_to_hub",
|
||||
"false",
|
||||
"--save_checkpoint_to_hub",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="requires --policy.repo_id"):
|
||||
cfg.validate()
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
# Importing lerobot_train eagerly pulls in lerobot.datasets, which needs the
|
||||
# `dataset` extra. The base CI tier runs without it, so skip the whole module there.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig # noqa: E402
|
||||
from lerobot.policies.act.configuration_act import (
|
||||
ACTConfig, # noqa: E402, F401 (registers --policy.type act)
|
||||
)
|
||||
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train # noqa: E402
|
||||
|
||||
|
||||
def _set_argv(monkeypatch, *args):
|
||||
monkeypatch.setattr(sys, "argv", ["lerobot-train", *args])
|
||||
|
||||
|
||||
def test_remote_target_detected_space_separated(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act", "--job.target", "a10g-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_remote_target_detected_equals(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target=t4-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_local_string_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target", "local")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_no_target_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_train_dispatches_to_submit_when_remote(monkeypatch):
|
||||
"""A remote --job.target short-circuits train() to the HF Jobs submitter."""
|
||||
import lerobot.jobs
|
||||
|
||||
captured = []
|
||||
monkeypatch.setattr(lerobot.jobs, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted")
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Returns the submitter's result and never enters the local training path.
|
||||
assert train(cfg) == "submitted"
|
||||
assert captured == [cfg]
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from lerobot.utils.hub import find_latest_hub_checkpoint
|
||||
|
||||
|
||||
def _patch_list_files(monkeypatch, files):
|
||||
api = MagicMock()
|
||||
api.list_repo_files.return_value = files
|
||||
# HfApi is imported into lerobot.utils.hub at module load, so patch it there.
|
||||
monkeypatch.setattr("lerobot.utils.hub.HfApi", lambda *a, **k: api)
|
||||
return api
|
||||
|
||||
|
||||
def test_find_latest_hub_checkpoint_picks_highest_step(monkeypatch):
|
||||
_patch_list_files(
|
||||
monkeypatch,
|
||||
[
|
||||
"README.md",
|
||||
"checkpoints/000500/pretrained_model/model.safetensors",
|
||||
"checkpoints/000500/training_state/training_step.json",
|
||||
"checkpoints/020000/pretrained_model/model.safetensors",
|
||||
"checkpoints/001000/training_state/training_step.json",
|
||||
],
|
||||
)
|
||||
# Numeric max, not lexicographic — "020000" beats "001000"/"000500".
|
||||
assert find_latest_hub_checkpoint("u/run") == "checkpoints/020000"
|
||||
|
||||
|
||||
def test_find_latest_hub_checkpoint_ignores_non_step_entries(monkeypatch):
|
||||
_patch_list_files(
|
||||
monkeypatch,
|
||||
["checkpoints/last/pretrained_model/model.safetensors", "config.json"],
|
||||
)
|
||||
# "last" (a symlink target name) is not a numeric step → no resolvable checkpoint.
|
||||
assert find_latest_hub_checkpoint("u/run") is None
|
||||
|
||||
|
||||
def test_find_latest_hub_checkpoint_none_when_no_checkpoints(monkeypatch):
|
||||
_patch_list_files(monkeypatch, ["config.json", "model.safetensors"])
|
||||
assert find_latest_hub_checkpoint("u/run") is None
|
||||
@@ -15,7 +15,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
@@ -24,6 +26,7 @@ from lerobot.common.train_utils import (
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
push_checkpoint_to_hub,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
@@ -151,3 +154,72 @@ def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
|
||||
assert loaded_step == 10
|
||||
assert loaded_optimizer is optimizer
|
||||
assert loaded_scheduler is scheduler
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run", private=True)
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is True
|
||||
assert api.create_repo.call_args.kwargs["repo_type"] == "model"
|
||||
api.upload_folder.assert_called_once()
|
||||
kwargs = api.upload_folder.call_args.kwargs
|
||||
assert kwargs["repo_id"] == "user/run"
|
||||
assert kwargs["repo_type"] == "model"
|
||||
assert kwargs["path_in_repo"] == "checkpoints/010000"
|
||||
assert kwargs["folder_path"] == str(ckpt)
|
||||
assert kwargs["commit_message"] == "checkpoint 010000"
|
||||
# A tag named after the checkpoint step is created so the checkpoint can be
|
||||
# recovered with --policy.pretrained_revision instead of a commit sha.
|
||||
api.create_tag.assert_called_once()
|
||||
tag_kwargs = api.create_tag.call_args.kwargs
|
||||
assert tag_kwargs["tag"] == "010000"
|
||||
assert tag_kwargs["revision"] == api.upload_folder.return_value.oid
|
||||
assert tag_kwargs["repo_type"] == "model"
|
||||
assert tag_kwargs["exist_ok"] is True
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr("lerobot.common.train_utils.HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run")
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is None
|
||||
|
||||
|
||||
def test_resolve_resume_checkpoint_downloads_latest_and_links(tmp_path, monkeypatch):
|
||||
from lerobot.common import train_utils
|
||||
|
||||
out = tmp_path / "run"
|
||||
|
||||
def fake_snapshot_download(repo_id, repo_type, allow_patterns, local_dir):
|
||||
# Mimic the Hub layout the real download materializes locally.
|
||||
assert allow_patterns == "checkpoints/020000/*"
|
||||
(Path(local_dir) / "checkpoints" / "020000" / "pretrained_model").mkdir(parents=True)
|
||||
return local_dir
|
||||
|
||||
monkeypatch.setattr("lerobot.common.train_utils.snapshot_download", fake_snapshot_download)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.common.train_utils.find_latest_hub_checkpoint", lambda repo_id: "checkpoints/020000"
|
||||
)
|
||||
|
||||
checkpoint_dir = train_utils.resolve_resume_checkpoint("u/run", out)
|
||||
|
||||
assert checkpoint_dir == out / CHECKPOINTS_DIR / "020000"
|
||||
last = out / CHECKPOINTS_DIR / LAST_CHECKPOINT_LINK
|
||||
assert last.is_symlink()
|
||||
# `last` points at the downloaded step dir.
|
||||
assert (last.parent / last.readlink()).resolve() == checkpoint_dir.resolve()
|
||||
|
||||
|
||||
def test_resolve_resume_checkpoint_raises_without_checkpoints(tmp_path, monkeypatch):
|
||||
from lerobot.common import train_utils
|
||||
|
||||
monkeypatch.setattr("lerobot.common.train_utils.find_latest_hub_checkpoint", lambda repo_id: None)
|
||||
with pytest.raises(FileNotFoundError, match="No checkpoint"):
|
||||
train_utils.resolve_resume_checkpoint("u/run", tmp_path / "run")
|
||||
|
||||
Reference in New Issue
Block a user