mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-23 03:07:16 +00:00
Compare commits
13 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ed8694c67f | |||
| 3bbdad8442 | |||
| 05fddeb2ba | |||
| 71c827f892 | |||
| 73782447f2 | |||
| 2d7a42011a | |||
| b06ad40888 | |||
| b3d74f80f0 | |||
| 552b4c3563 | |||
| 8bf6056d14 | |||
| da92db8fc0 | |||
| 2b0834bcb8 | |||
| 287c823f13 |
+1
-1
@@ -138,7 +138,7 @@ lerobot-replay --robot.type=so101_follower --robot.port=<FOLLOWER_PORT> --robot.
|
||||
--dataset.repo_id=${HF_USER}/my_task --dataset.episode=0
|
||||
```
|
||||
|
||||
**4.9 Train** (default: ACT — fastest, lowest memory). Apple silicon: `--policy.device=mps`. See §6/§7 for policy and duration.
|
||||
**4.9 Train** (default: ACT — fastest, lowest memory). Apple silicon: `--policy.device=mps`. No local GPU? Add `--job.target=<flavor>` (e.g. `a10g-small`, list them with `hf jobs hardware`) to run on Hugging Face Jobs instead. See §6/§7 for policy and duration.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
|
||||
@@ -136,6 +136,7 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
|
||||
|
||||
## Citation
|
||||
|
||||
|
||||
@@ -120,6 +120,8 @@ lerobot-train \
|
||||
--steps=20000
|
||||
```
|
||||
|
||||
No local GPU? Add `--job.target=<flavor>` (e.g. `a10g-small`) to either command and `lerobot-train` runs it on [Hugging Face Jobs](https://huggingface.co/docs/hub/jobs) instead — it uploads a local-only dataset for you and pushes the trained model. List flavors with `hf jobs hardware`.
|
||||
|
||||
### Inference
|
||||
|
||||
Inference means running the trained policy/model on a robot. For that we use `lerobot-rollout`. You will need to provide a path to your policy. It can be a local path or a path to Hugging Face for example "lerobot/folding_latest". Your cameras configuration needs to match what was used when collecting the dataset. Duration is in seconds if unspecified, it will run forever.
|
||||
|
||||
@@ -96,3 +96,4 @@ Notes:
|
||||
- The leading `nvidia-smi` is a quick sanity check that CUDA is visible inside the container — useful to fail fast if the flavor or driver mismatched.
|
||||
- The default Job timeout is 30 minutes; pass `--timeout 4h` (or longer) for real training.
|
||||
- `--flavor` maps onto the table above: `t4-small`/`t4-medium` (T4, ACT only), `l4x1`/`l4x4` (L4 24 GB), `a10g-small/large/largex2/largex4` (A10G 24 GB scaled out), `a100-large` (A100). For the current full catalogue + pricing see [https://huggingface.co/docs/hub/jobs](https://huggingface.co/docs/hub/jobs).
|
||||
- Prefer not to write the `hf jobs run` wrapper yourself? `lerobot-train` can submit the job for you: just add `--job.target=<flavor>` to a normal training command and it handles dataset upload, log streaming, and the final model push. See the [imitation-learning training guide](./il_robots).
|
||||
|
||||
@@ -518,7 +518,9 @@ If your local computer doesn't have a powerful GPU you could utilize Google Cola
|
||||
|
||||
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||
|
||||
To run the training use this command:
|
||||
> **Tip:** if you just want to launch a standard training run, you can skip building the command below and use the integrated **Train on HF Jobs via `--job.target`** flow described further down — `lerobot-train` then submits the job, uploads a local-only dataset for you, and streams the logs.
|
||||
|
||||
To run the training manually use this command:
|
||||
|
||||
<hfoptions id="train_with_hf_jobs">
|
||||
<hfoption id="Command">
|
||||
@@ -591,6 +593,33 @@ Once the training is started you can go to [Jobs](https://huggingface.co/setting
|
||||
|
||||
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||
|
||||
#### Train on HF Jobs via `--job.target` (integrated CLI)
|
||||
|
||||
`lerobot-train` runs locally by default. To run on a HuggingFace GPU without constructing the Docker command yourself, pass `--job.target` with a hardware flavor name:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/so101_test \
|
||||
--policy.type=act \
|
||||
--policy.repo_id=${HF_USER}/my_policy \
|
||||
--job.target=a10g-small
|
||||
```
|
||||
|
||||
List available flavors and pricing with `hf jobs hardware`. The run streams its logs to your terminal; press Ctrl-C to detach (the job keeps running in the cloud). Re-attach or cancel with:
|
||||
|
||||
```bash
|
||||
hf jobs logs <job-id>
|
||||
hf jobs cancel <job-id>
|
||||
```
|
||||
|
||||
If your dataset exists only locally (not yet on the Hub), it is automatically pushed to a **private** Hub repo so the job can download it by `repo_id` (nothing is made public). The trained model is pushed to the model repo at the end of the run. To also push every intermediate checkpoint to the Hub as it is saved (so you can monitor progress mid-run), add `--save_checkpoint_to_hub=true` — this requires a runtime image that includes this feature.
|
||||
|
||||
Every job (and any dataset pushed by the run) is tagged `lerobot` so it's easy to find on the Hub. Add your own with `--job.tags '["my-tag"]'`.
|
||||
|
||||
By default the job runs until training finishes, with no time limit. Cap it with an HF Jobs duration string if you want a hard ceiling, e.g. `--job.timeout=4h`.
|
||||
|
||||
**Prerequisites:** run `hf auth login` before submitting. For Weights & Biases integration, run `wandb login` or set `WANDB_API_KEY` on your machine — the key is forwarded to the job automatically.
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
@@ -113,6 +113,61 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
|
||||
--policy=act
|
||||
```
|
||||
|
||||
## Training Large Models with FSDP
|
||||
|
||||
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
|
||||
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
|
||||
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
|
||||
|
||||
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
|
||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
||||
--policy.type=<your_policy> \
|
||||
--output_dir=outputs/train/my_policy_fsdp
|
||||
```
|
||||
|
||||
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
|
||||
|
||||
```yaml
|
||||
compute_environment: LOCAL_MACHINE
|
||||
distributed_type: FSDP
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 4
|
||||
fsdp_config:
|
||||
fsdp_version: 1
|
||||
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
|
||||
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
```
|
||||
|
||||
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
|
||||
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
|
||||
optimizer before `accelerator.prepare()`.
|
||||
|
||||
### FSDP checkpoints
|
||||
|
||||
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
|
||||
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two things to look out for:
|
||||
|
||||
- **Checkpoints store fp32 weights.** Under mixed precision (`bf16`/`fp16`) FSDP keeps an fp32 master
|
||||
copy, and the checkpoint saves it (~2× the bf16 size on disk) so training can resume consistently
|
||||
with the fp32 optimizer state; `from_pretrained` casts back to the policy dtype on load. FSDP-specific
|
||||
caveat: an fp32 checkpoint is materialized in full precision on the target device _before_ casting,
|
||||
so loading it for inference on a tight GPU can OOM even when the bf16 model would fit — load on CPU
|
||||
first, or cast `model.safetensors` to the deployment dtype offline.
|
||||
- The sharded optimizer state is gathered into a full (world-size-independent) state dict and saved
|
||||
alongside the model in the same `optimizer_state.safetensors` / `optimizer_param_groups.json`
|
||||
format as single-GPU training, so **resume-from-checkpoint is supported** with `--resume=true`.
|
||||
Resume reshards both the model and the optimizer state to the _current_ FSDP topology, so you can
|
||||
resume an FSDP checkpoint on a different number of GPUs. Note that the data sampler is only
|
||||
sample-exact when the world size and batch size match the original run (a warning is logged
|
||||
otherwise); the optimizer/model state itself is unaffected.
|
||||
|
||||
## Notes
|
||||
|
||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||
|
||||
@@ -442,11 +442,12 @@ class OpenCVCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
raw_frame = self._read_from_hardware()
|
||||
processed_frame = self._postprocess_image(raw_frame)
|
||||
|
||||
@@ -471,11 +471,12 @@ class RealSenseCamera(Camera):
|
||||
|
||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
color_frame_raw = frame.get_color_frame()
|
||||
|
||||
@@ -246,11 +246,12 @@ class ZMQCamera(Camera):
|
||||
"""
|
||||
Internal loop run by the background thread for asynchronous reading.
|
||||
"""
|
||||
if self.stop_event is None:
|
||||
stop_event = self.stop_event
|
||||
if stop_event is None:
|
||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||
|
||||
failure_count = 0
|
||||
while not self.stop_event.is_set():
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
frame = self._read_from_hardware()
|
||||
capture_time = time.perf_counter()
|
||||
|
||||
@@ -21,6 +21,7 @@ from torch.optim.lr_scheduler import LRScheduler
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.optim import (
|
||||
load_optimizer_state,
|
||||
load_optimizer_state_dict,
|
||||
load_scheduler_state,
|
||||
save_optimizer_state,
|
||||
save_scheduler_state,
|
||||
@@ -98,6 +99,8 @@ def save_checkpoint(
|
||||
postprocessor: PolicyProcessorPipeline | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
model_state_dict: dict | None = None,
|
||||
optim_state_dict: dict | None = None,
|
||||
) -> None:
|
||||
"""This function creates the following directory structure:
|
||||
|
||||
@@ -127,9 +130,18 @@ def save_checkpoint(
|
||||
resume. Defaults to None (not recorded).
|
||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||
resume. Defaults to None (not recorded).
|
||||
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
|
||||
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
|
||||
cross-rank collective and passes it here so rank 0 can write it directly. It holds
|
||||
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
|
||||
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
|
||||
optim_state_dict: Pre-gathered full (unsharded) optimizer state dict. Required under FSDP
|
||||
(gathered alongside `model_state_dict` via `gather_fsdp_state_dicts`); saved in the same
|
||||
safetensors format as the single-GPU path. When None, `optimizer.state_dict()` is used.
|
||||
Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
|
||||
cfg.save_pretrained(pretrained_dir)
|
||||
if cfg.peft is not None:
|
||||
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
||||
@@ -140,7 +152,13 @@ def save_checkpoint(
|
||||
if postprocessor is not None:
|
||||
postprocessor.save_pretrained(pretrained_dir)
|
||||
save_training_state(
|
||||
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||
checkpoint_dir,
|
||||
step,
|
||||
optimizer,
|
||||
scheduler,
|
||||
num_processes=num_processes,
|
||||
batch_size=batch_size,
|
||||
optim_state_dict=optim_state_dict,
|
||||
)
|
||||
|
||||
|
||||
@@ -151,6 +169,7 @@ def save_training_state(
|
||||
scheduler: LRScheduler | None = None,
|
||||
num_processes: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
optim_state_dict: dict | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -164,19 +183,21 @@ def save_training_state(
|
||||
Defaults to None.
|
||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||
optim_state_dict: Pre-gathered full optimizer state dict (for FSDP). Saved instead of
|
||||
`optimizer.state_dict()` when provided. Defaults to None.
|
||||
"""
|
||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||
save_rng_state(save_dir)
|
||||
if optimizer is not None:
|
||||
save_optimizer_state(optimizer, save_dir)
|
||||
save_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
|
||||
if scheduler is not None:
|
||||
save_scheduler_state(scheduler, save_dir)
|
||||
|
||||
|
||||
def load_training_state(
|
||||
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
|
||||
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None, load_optimizer: bool = True
|
||||
) -> tuple[int, Optimizer, LRScheduler | None]:
|
||||
"""
|
||||
Loads the training step, optimizer state, scheduler state, and rng state.
|
||||
@@ -186,6 +207,10 @@ def load_training_state(
|
||||
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
|
||||
optimizer (Optimizer): The optimizer to load the state_dict to.
|
||||
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
|
||||
load_optimizer (bool, optional): Whether to load the optimizer state from disk. Defaults to
|
||||
True. Set to False under FSDP, where the sharded optimizer state must be loaded after
|
||||
`accelerator.prepare()` via `load_fsdp_optimizer_state` (the optimizer is returned
|
||||
untouched here).
|
||||
|
||||
Raises:
|
||||
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
|
||||
@@ -200,8 +225,86 @@ def load_training_state(
|
||||
|
||||
load_rng_state(training_state_dir)
|
||||
step = load_training_step(training_state_dir)
|
||||
optimizer = load_optimizer_state(optimizer, training_state_dir)
|
||||
if load_optimizer:
|
||||
optimizer = load_optimizer_state(optimizer, training_state_dir)
|
||||
if scheduler is not None:
|
||||
scheduler = load_scheduler_state(scheduler, training_state_dir)
|
||||
|
||||
return step, optimizer, scheduler
|
||||
|
||||
|
||||
def gather_fsdp_state_dicts(model, optimizer) -> tuple[dict, dict]:
|
||||
"""Gather the full (unsharded) model and optimizer state dicts under FSDP.
|
||||
|
||||
`model.state_dict()` and `FSDP.optim_state_dict(...)` are cross-rank collectives, so this must be
|
||||
called on *every* rank with the prepared (FSDP-wrapped) `model` and `optimizer`. With
|
||||
`rank0_only=True` and `offload_to_cpu=True`, every rank runs the all-gather but only rank 0
|
||||
materializes the full dicts (the others get empty dicts) and they are kept on CPU to bound GPU
|
||||
memory. The returned optimizer state dict is keyed by parameter FQNs and is world-size
|
||||
independent; `load_fsdp_optimizer_state` reshards it on resume.
|
||||
|
||||
Returns:
|
||||
(model_state_dict, optim_state_dict): full dicts on rank 0, empty dicts on other ranks.
|
||||
"""
|
||||
from torch.distributed.fsdp import (
|
||||
FullOptimStateDictConfig,
|
||||
FullStateDictConfig,
|
||||
FullyShardedDataParallel as FSDP, # noqa F401
|
||||
StateDictType,
|
||||
)
|
||||
|
||||
state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
||||
model_state_dict = model.state_dict()
|
||||
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
|
||||
return model_state_dict, optim_state_dict
|
||||
|
||||
|
||||
def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
|
||||
"""Load the FSDP optimizer state (saved as safetensors) and reshard it into the optimizer.
|
||||
|
||||
This is a cross-rank collective and must be called on every rank *after* `accelerator.prepare()`
|
||||
with the prepared (FSDP-wrapped) `model` and `optimizer`. The saved state is the full,
|
||||
world-size-independent optimizer state (keyed by parameter FQNs); `FSDP.optim_state_dict_to_load`
|
||||
reshards it to the current FSDP topology, so resume on a different number of GPUs works.
|
||||
"""
|
||||
from torch.distributed.fsdp import (
|
||||
FullOptimStateDictConfig,
|
||||
FullStateDictConfig,
|
||||
FullyShardedDataParallel as FSDP, # noqa F401
|
||||
StateDictType,
|
||||
)
|
||||
|
||||
# Every rank reads the same full state from the (shared) checkpoint dir, so rank0_only=False.
|
||||
full_osd = load_optimizer_state_dict(checkpoint_dir / TRAINING_STATE_DIR)
|
||||
state_cfg = FullStateDictConfig(rank0_only=False)
|
||||
optim_cfg = FullOptimStateDictConfig(rank0_only=False)
|
||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
||||
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
|
||||
optimizer.load_state_dict(sharded_osd)
|
||||
|
||||
|
||||
def push_checkpoint_to_hub(
|
||||
checkpoint_dir: Path,
|
||||
repo_id: str,
|
||||
*,
|
||||
private: bool | None = None,
|
||||
) -> None:
|
||||
"""Upload a saved checkpoint directory to the Hub under checkpoints/<name>/.
|
||||
|
||||
Called once per save step when save_checkpoint_to_hub is enabled, so a
|
||||
timed-out or crashed run still leaves recoverable checkpoints on the Hub.
|
||||
The model repo is created idempotently.
|
||||
"""
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)
|
||||
api.upload_folder(
|
||||
folder_path=str(checkpoint_dir),
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
path_in_repo=f"checkpoints/{checkpoint_dir.name}",
|
||||
commit_message=f"checkpoint {checkpoint_dir.name}",
|
||||
)
|
||||
|
||||
@@ -180,24 +180,26 @@ class WandBLogger:
|
||||
self._wandb_custom_step_key.add(new_custom_key)
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
batch_data = {}
|
||||
for k, v in d.items():
|
||||
# Skip the custom step key here, it's added to the batch below.
|
||||
if custom_step_key is not None and k == custom_step_key:
|
||||
continue
|
||||
|
||||
if not isinstance(v, (int | float | str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
continue
|
||||
|
||||
# Do not log the custom step key itself.
|
||||
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
|
||||
continue
|
||||
batch_data[f"{mode}/{k}"] = v
|
||||
|
||||
if batch_data:
|
||||
if custom_step_key is not None:
|
||||
value_custom_step = d[custom_step_key]
|
||||
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
|
||||
self._wandb.log(data)
|
||||
continue
|
||||
|
||||
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
|
||||
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
|
||||
self._wandb.log(batch_data)
|
||||
else:
|
||||
self._wandb.log(data=batch_data, step=step)
|
||||
|
||||
def log_video(self, video_path: str, step: int, mode: str = "train"):
|
||||
if mode not in {"train", "eval"}:
|
||||
|
||||
@@ -123,3 +123,31 @@ class PeftConfig:
|
||||
# If None, the PEFT library defaults to alpha=8, which may dampen high-rank adapters.
|
||||
# Common values are r (alpha == rank) or 2*r.
|
||||
lora_alpha: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JobConfig:
|
||||
# Where training runs. None (omitted) or "local" runs on this machine.
|
||||
# Any other value is an HF Jobs flavor and submits the run to HF Jobs.
|
||||
# List available flavors + pricing with `hf jobs hardware` command.
|
||||
target: str | None = None
|
||||
# Runtime image for the remote job (ignored for local runs).
|
||||
image: str = "huggingface/lerobot-gpu:latest"
|
||||
# Max wall-clock for the remote job as an HF Jobs duration string (e.g. "2h").
|
||||
# None (default) imposes no timeout — the job runs until the command finishes.
|
||||
timeout: str | None = None
|
||||
# Submit and exit instead of streaming the job logs in the foreground.
|
||||
detach: bool = False
|
||||
# Extra tags attached to the HF job and to any dataset this run pushes to the
|
||||
# Hub. A "lerobot" tag is always added; e.g. --job.tags '["lelab"]' adds more.
|
||||
tags: list[str] = field(default_factory=list)
|
||||
|
||||
@staticmethod
|
||||
def is_remote_target(target: str | None) -> bool:
|
||||
"""True when `target` names an HF Jobs flavor rather than a local run."""
|
||||
return target not in (None, "local")
|
||||
|
||||
@property
|
||||
def is_remote(self) -> bool:
|
||||
"""True when training should run on HF Jobs rather than this machine."""
|
||||
return self.is_remote_target(self.target)
|
||||
|
||||
@@ -79,6 +79,8 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): # type: igno
|
||||
# Either the repo ID of a model hosted on the Hub or a path to a directory containing weights
|
||||
# saved using `Policy.save_pretrained`. If not provided, the policy is initialized from scratch.
|
||||
pretrained_path: Path | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
|
||||
@@ -56,6 +56,8 @@ class RewardModelConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
device: str | None = None
|
||||
|
||||
pretrained_path: str | None = None
|
||||
# Optional Hub revision (commit hash, branch, or tag) to pin the pretrained reward model version.
|
||||
pretrained_revision: str | None = None
|
||||
|
||||
push_to_hub: bool = False
|
||||
repo_id: str | None = None
|
||||
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EvalConfig, JobConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -113,6 +113,13 @@ class TrainPipelineConfig(HubMixin):
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# Where to run training (local default, or an HF Jobs flavor). See JobConfig.
|
||||
job: JobConfig = field(default_factory=JobConfig)
|
||||
# Push each saved checkpoint to the Hub (policy.repo_id) as it is written, not
|
||||
# just the final model (useful to monitor progress mid-run). Optional; the
|
||||
# final model is pushed regardless. Works the same locally and remotely.
|
||||
save_checkpoint_to_hub: bool = False
|
||||
|
||||
# Sample weighting configuration (e.g., for RA-BC training)
|
||||
sample_weighting: SampleWeightingConfig | None = None
|
||||
|
||||
@@ -211,6 +218,9 @@ class TrainPipelineConfig(HubMixin):
|
||||
if hasattr(active_cfg, "push_to_hub") and active_cfg.push_to_hub and not active_cfg.repo_id:
|
||||
raise ValueError("'repo_id' argument missing. Please specify it to push the model to the hub.")
|
||||
|
||||
if self.save_checkpoint_to_hub and not (self.policy is not None and self.policy.repo_id):
|
||||
raise ValueError("save_checkpoint_to_hub requires --policy.repo_id.")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""Keys for draccus pretrained-path loading."""
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -709,7 +710,7 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
obj.root.mkdir(parents=True, exist_ok=False)
|
||||
|
||||
features = {**features, **DEFAULT_FEATURES}
|
||||
features = {**deepcopy(features), **DEFAULT_FEATURES}
|
||||
_validate_feature_names(features)
|
||||
|
||||
obj.tasks = None
|
||||
|
||||
@@ -74,6 +74,8 @@ class DatasetReader:
|
||||
self.episodes = episodes
|
||||
self._tolerance_s = tolerance_s
|
||||
self._video_backend = video_backend
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._image_transforms = image_transforms
|
||||
self._return_uint8 = return_uint8
|
||||
|
||||
@@ -86,6 +88,16 @@ class DatasetReader:
|
||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||
|
||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||
"""Replace the transform applied to visual observations."""
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._image_transforms = image_transforms
|
||||
|
||||
def clear_image_transforms(self) -> None:
|
||||
"""Remove the transform applied to visual observations."""
|
||||
self._image_transforms = None
|
||||
|
||||
def try_load(self) -> bool:
|
||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||
try:
|
||||
|
||||
@@ -27,6 +27,7 @@ import logging
|
||||
import shutil
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
@@ -1101,7 +1102,9 @@ def _copy_episodes_metadata_and_stats(
|
||||
if dst_meta.video_keys and src_dataset.meta.video_keys:
|
||||
for key in dst_meta.video_keys:
|
||||
if key in src_dataset.meta.features:
|
||||
dst_meta.info.features[key]["info"] = src_dataset.meta.info.features[key].get("info", {})
|
||||
dst_meta.info.features[key]["info"] = deepcopy(
|
||||
src_dataset.meta.info.features[key].get("info", {})
|
||||
)
|
||||
|
||||
write_info(dst_meta.info, dst_meta.root)
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
|
||||
Returns:
|
||||
dict: The statistics dictionary with values cast to numpy arrays.
|
||||
"""
|
||||
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
|
||||
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
|
||||
return unflatten_dict(stats)
|
||||
|
||||
|
||||
|
||||
@@ -201,8 +201,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self._requested_root = Path(root) if root else None
|
||||
self.reader = None
|
||||
self.set_image_transforms(image_transforms)
|
||||
self.delta_timestamps = delta_timestamps
|
||||
self.tolerance_s = tolerance_s
|
||||
self.revision = revision if revision else CODEBASE_VERSION
|
||||
@@ -249,6 +247,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
image_transforms=image_transforms,
|
||||
return_uint8=self._return_uint8,
|
||||
)
|
||||
self.image_transforms = image_transforms
|
||||
|
||||
# Load actual data
|
||||
if force_cache_sync or not self.reader.try_load():
|
||||
@@ -505,15 +504,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||
"""Replace the transform applied to visual observations."""
|
||||
if image_transforms is not None and not callable(image_transforms):
|
||||
raise TypeError("image_transforms must be callable or None.")
|
||||
self._ensure_reader().set_image_transforms(image_transforms)
|
||||
self.image_transforms = image_transforms
|
||||
if self.reader is not None:
|
||||
self.reader._image_transforms = image_transforms
|
||||
|
||||
def clear_image_transforms(self) -> None:
|
||||
"""Remove the transform applied to visual observations."""
|
||||
self.set_image_transforms(None)
|
||||
if self.reader is not None:
|
||||
self.reader.set_image_transforms(None)
|
||||
self.image_transforms = None
|
||||
|
||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||
|
||||
|
||||
@@ -126,6 +126,26 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
if "camera_obs" in observations:
|
||||
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
|
||||
|
||||
# Pass through any remaining ndarray/tensor keys not already handled above,
|
||||
# so env plugins can expose extra observation keys via get_env_processors().
|
||||
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
|
||||
for key, value in observations.items():
|
||||
if key in _handled:
|
||||
continue
|
||||
target = f"{OBS_STR}.{key}"
|
||||
if target in return_observations:
|
||||
continue
|
||||
if isinstance(value, np.ndarray):
|
||||
val = torch.from_numpy(value).float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
elif isinstance(value, Tensor):
|
||||
val = value.float()
|
||||
if val.dim() == 1:
|
||||
val = val.unsqueeze(0)
|
||||
return_observations[target] = val
|
||||
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .hf import submit_to_hf
|
||||
|
||||
__all__ = ["submit_to_hf"]
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Make a training dataset reachable from an HF Job pod.
|
||||
|
||||
The pod can't see the host's ~/.cache/huggingface/lerobot, so the dataset has to
|
||||
live on the Hub: the pod downloads it by repo_id at train time (the forwarded
|
||||
HF_TOKEN covers private datasets). A dataset already on the Hub is used as-is; a
|
||||
local-only dataset is pushed to a PRIVATE repo first (never public).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
|
||||
def ensure_dataset_available(repo_id: str, *, api, tags: list[str] | None = None) -> None:
|
||||
"""Ensure repo_id resolves on the Hub, pushing a local-only dataset privately first.
|
||||
|
||||
`tags` are attached to the dataset only when we push it (an already-on-Hub
|
||||
dataset is left untouched). Raises RuntimeError if the dataset is neither on
|
||||
the Hub nor in the local cache.
|
||||
"""
|
||||
try:
|
||||
api.dataset_info(repo_id)
|
||||
return
|
||||
except RepositoryNotFoundError:
|
||||
pass
|
||||
|
||||
cache_root = Path(os.environ.get("HF_LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser()
|
||||
local_present = (cache_root / repo_id / "meta" / "info.json").is_file()
|
||||
if not local_present:
|
||||
raise RuntimeError(
|
||||
f"Dataset '{repo_id}' is neither on the Hub nor in the local cache "
|
||||
f"({cache_root}). Record or download it first."
|
||||
)
|
||||
|
||||
print(f"[dataset] '{repo_id}' is local-only; pushing to a PRIVATE Hub repo...")
|
||||
# Lazy import: LeRobotDataset pulls in heavy dataset deps; defer until actually needed.
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
|
||||
LeRobotDataset(repo_id).push_to_hub(private=True, tags=tags)
|
||||
print(f"[dataset] '{repo_id}' uploaded (private). The job will download it by repo_id.")
|
||||
@@ -0,0 +1,332 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Run a lerobot training on HF Jobs (HuggingFace GPUs).
|
||||
|
||||
Ported and simplified from lelab's runners/hf_cloud.py: no UI log queue, no
|
||||
registry — just submit and stream to stdout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import datetime as dt
|
||||
import io
|
||||
import json
|
||||
import netrc
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import tempfile
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import get_token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
_SLUG_RE = re.compile(r"[^a-zA-Z0-9._-]+")
|
||||
|
||||
_TERMINAL_STAGES = {"COMPLETED", "CANCELED", "ERROR", "DELETED"}
|
||||
|
||||
# Always attached to remote jobs and pushed datasets so LeRobot-originated work
|
||||
# is identifiable on the Hub; callers (e.g. LeLab) add their own via --job.tags.
|
||||
LEROBOT_TAG = "lerobot"
|
||||
|
||||
|
||||
def resolve_job_tags(extra: list[str] | None) -> list[str]:
|
||||
"""Return the tag list for a run: the lerobot tag plus any extras, deduped, order-stable."""
|
||||
tags = [LEROBOT_TAG, *(extra or [])]
|
||||
seen: set[str] = set()
|
||||
return [t for t in tags if not (t in seen or seen.add(t))]
|
||||
|
||||
|
||||
def resolve_wandb_api_key() -> str | None:
|
||||
"""Host's wandb key for forwarding to the job: $WANDB_API_KEY, else ~/.netrc."""
|
||||
key = os.environ.get("WANDB_API_KEY")
|
||||
if key:
|
||||
return key
|
||||
try:
|
||||
rc = netrc.netrc()
|
||||
except (FileNotFoundError, netrc.NetrcParseError, OSError):
|
||||
return None
|
||||
auth = rc.authenticators("api.wandb.ai")
|
||||
if auth is None:
|
||||
return None
|
||||
_login, _account, password = auth
|
||||
return password or None
|
||||
|
||||
|
||||
def build_repo_id(username: str, job_name: str, now: dt.datetime) -> str:
|
||||
"""Generate the model repo id for a remote run: <user>/<job_name>_<timestamp>."""
|
||||
slug = _SLUG_RE.sub("-", job_name).strip("-") or "train"
|
||||
stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
return f"{username}/{slug}_{stamp}"
|
||||
|
||||
|
||||
def build_remote_config_file(cfg, repo_id: str, dest: Path, tags: list[str] | None = None) -> Path:
|
||||
"""Write a train_config.json for the pod, with remote overrides applied.
|
||||
|
||||
The pod runs `lerobot-train --config_path=<dest>` and downloads the dataset
|
||||
by repo_id into its own cache. Client-only fields are stripped so the config
|
||||
is accepted by the trainer image: `job` (pure client orchestration) is always
|
||||
removed, and `save_checkpoint_to_hub` is removed unless explicitly enabled —
|
||||
older lerobot images reject unknown keys, so the default keeps the config
|
||||
compatible with the released `lerobot-gpu` image. `tags` are merged into
|
||||
policy.tags so the trained model the pod pushes carries them too.
|
||||
"""
|
||||
remote = copy.deepcopy(cfg)
|
||||
remote.policy.push_to_hub = True
|
||||
remote.policy.repo_id = repo_id
|
||||
# Don't pin the client's resolved device (e.g. "mps"); let the pod auto-detect its GPU.
|
||||
remote.policy.device = None
|
||||
# Drop any host-local dataset root; the pod resolves the dataset by repo_id.
|
||||
remote.dataset.root = None
|
||||
if tags:
|
||||
existing = list(remote.policy.tags or [])
|
||||
remote.policy.tags = existing + [t for t in tags if t not in existing]
|
||||
|
||||
# Round-trip through draccus to get the canonical, pod-parseable layout, then
|
||||
# drop the keys the released trainer image doesn't know about.
|
||||
buf = io.StringIO()
|
||||
with draccus.config_type("json"):
|
||||
draccus.dump(remote, buf, indent=4)
|
||||
data = json.loads(buf.getvalue())
|
||||
data.pop("job", None)
|
||||
if not remote.save_checkpoint_to_hub:
|
||||
data.pop("save_checkpoint_to_hub", None)
|
||||
|
||||
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||
dest.write_text(json.dumps(data, indent=4))
|
||||
return dest
|
||||
|
||||
|
||||
def _stage_config_on_hub(cfg, repo_id: str, token: str, tags: list[str] | None = None) -> str:
|
||||
"""Upload train_config.json to the model repo and return the repo_id for --config_path."""
|
||||
from huggingface_hub import create_repo, upload_file
|
||||
|
||||
create_repo(repo_id, repo_type="model", private=True, exist_ok=True, token=token)
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
config_path = build_remote_config_file(cfg, repo_id, Path(tmp) / "train_config.json", tags=tags)
|
||||
upload_file(
|
||||
path_or_fileobj=config_path,
|
||||
path_in_repo="train_config.json",
|
||||
repo_id=repo_id,
|
||||
repo_type="model",
|
||||
token=token,
|
||||
)
|
||||
return repo_id
|
||||
|
||||
|
||||
def _tail_logs(
|
||||
job_id: str,
|
||||
done: threading.Event,
|
||||
success_marker: str | None = None,
|
||||
success_event: threading.Event | None = None,
|
||||
) -> None:
|
||||
"""Stream job logs to stdout, reconnecting on dropped streams until done is set.
|
||||
|
||||
Each reconnect re-fetches the full buffered log, so we track how many lines
|
||||
were already printed and skip them — otherwise a fast-failing job's traceback
|
||||
gets reprinted on every reconnect.
|
||||
|
||||
When `success_marker` appears in a line, set `success_event` and `done` so the
|
||||
caller can finish as soon as the trained model lands on the Hub, rather than
|
||||
waiting out the platform's post-run finalization (which can add ~30s).
|
||||
"""
|
||||
from huggingface_hub import fetch_job_logs
|
||||
|
||||
printed = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
seen = 0
|
||||
for line in fetch_job_logs(job_id=job_id, follow=True):
|
||||
seen += 1
|
||||
if seen <= printed:
|
||||
continue # already shown on a previous connection
|
||||
printed = seen
|
||||
# fetch_job_logs yields SSE data without trailing newlines, so add one
|
||||
# per entry — otherwise all log lines concatenate onto a single line.
|
||||
print(line.rstrip("\n"), flush=True)
|
||||
if success_marker and success_event is not None and success_marker in line:
|
||||
success_event.set()
|
||||
done.set()
|
||||
return
|
||||
if done.is_set():
|
||||
return
|
||||
# Stream closed cleanly. Wait a moment so the status poller can mark
|
||||
# the job terminal before we reconnect (avoids re-tailing the buffer).
|
||||
if done.wait(3):
|
||||
return
|
||||
except Exception:
|
||||
if done.wait(2):
|
||||
return
|
||||
|
||||
|
||||
def _poll_until_done(
|
||||
job_id: str,
|
||||
done: threading.Event,
|
||||
poll_interval: float = 5.0,
|
||||
status_holder: dict | None = None,
|
||||
max_failures: int = 6,
|
||||
) -> str | None:
|
||||
"""Poll inspect_job until a terminal stage or until `done` is set.
|
||||
|
||||
Returns the terminal stage string, or None if `done` was set first (detach)
|
||||
or after `max_failures` consecutive inspect_job errors. When a terminal stage
|
||||
is reached and `status_holder` is given, records `status_holder["message"]`
|
||||
(the platform's status message, e.g. "Job timeout").
|
||||
"""
|
||||
from huggingface_hub import inspect_job
|
||||
|
||||
failures = 0
|
||||
while not done.is_set():
|
||||
try:
|
||||
info = inspect_job(job_id=job_id)
|
||||
failures = 0
|
||||
stage = info.status.stage.value
|
||||
if stage in _TERMINAL_STAGES:
|
||||
if status_holder is not None:
|
||||
status_holder["message"] = getattr(info.status, "message", None)
|
||||
done.set()
|
||||
return stage
|
||||
except Exception:
|
||||
failures += 1
|
||||
if failures >= max_failures:
|
||||
done.set()
|
||||
return None
|
||||
done.wait(poll_interval)
|
||||
return None
|
||||
|
||||
|
||||
def submit_to_hf(cfg: TrainPipelineConfig) -> None:
|
||||
"""Submit a training job to HF Jobs infrastructure.
|
||||
|
||||
Validates cfg, resolves credentials, stages the config on the Hub, submits
|
||||
the job, then either tails logs until completion or detaches immediately.
|
||||
Ctrl-C detaches without cancelling the remote job.
|
||||
"""
|
||||
from huggingface_hub import HfApi, run_job
|
||||
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
token = get_token()
|
||||
if not token:
|
||||
raise RuntimeError("Not logged in to Hugging Face. Run `hf auth login` first.")
|
||||
|
||||
api = HfApi(token=token)
|
||||
user_info = api.whoami(token=token)
|
||||
username = user_info["name"]
|
||||
|
||||
now = dt.datetime.now()
|
||||
if cfg.policy is not None:
|
||||
base_name = cfg.job_name or cfg.policy.type
|
||||
repo_id = cfg.policy.repo_id or build_repo_id(username, base_name, now)
|
||||
cfg.policy.repo_id = repo_id
|
||||
cfg.policy.push_to_hub = True
|
||||
else:
|
||||
# Path-based policy is resolved inside validate(); fall back to a generic slug.
|
||||
repo_id = build_repo_id(username, cfg.job_name or "train", now)
|
||||
|
||||
cfg.validate()
|
||||
|
||||
secrets: dict[str, str] = {"HF_TOKEN": token}
|
||||
if cfg.wandb.enable:
|
||||
wandb_key = resolve_wandb_api_key()
|
||||
if wandb_key is None:
|
||||
raise ValueError(
|
||||
"wandb is enabled but no WANDB_API_KEY found. "
|
||||
"Set it via `export WANDB_API_KEY=...` or add it to ~/.netrc."
|
||||
)
|
||||
secrets["WANDB_API_KEY"] = wandb_key
|
||||
|
||||
tags = resolve_job_tags(cfg.job.tags)
|
||||
ensure_dataset_available(cfg.dataset.repo_id, api=api, tags=tags)
|
||||
|
||||
config_repo_id = _stage_config_on_hub(cfg, repo_id, token, tags=tags)
|
||||
command = ["lerobot-train", f"--config_path={config_repo_id}"]
|
||||
|
||||
print(f"Submitting job to HF Jobs (flavor={cfg.job.target}, image={cfg.job.image}) ...")
|
||||
job_info = run_job(
|
||||
image=cfg.job.image,
|
||||
command=command,
|
||||
flavor=cfg.job.target,
|
||||
secrets=secrets,
|
||||
timeout=cfg.job.timeout,
|
||||
# HF Jobs labels are key/value; expose each tag as a queryable label.
|
||||
labels=dict.fromkeys(tags, "true"),
|
||||
)
|
||||
job_id = job_info.id
|
||||
job_url = getattr(job_info, "url", None)
|
||||
print(f"Job submitted: {job_id}")
|
||||
if job_url:
|
||||
print(f" Job page: {job_url}")
|
||||
print(f" Model repo: https://huggingface.co/{repo_id}")
|
||||
print(f" Monitor: hf jobs logs {job_id}")
|
||||
print(f" Cancel: hf jobs cancel {job_id}")
|
||||
|
||||
if cfg.job.detach:
|
||||
return
|
||||
|
||||
done = threading.Event()
|
||||
detached = threading.Event()
|
||||
pushed_ok = threading.Event()
|
||||
stage_holder: dict[str, str | None] = {}
|
||||
|
||||
def _poll() -> None:
|
||||
stage_holder["stage"] = _poll_until_done(job_id, done, status_holder=stage_holder)
|
||||
|
||||
poll_thread = threading.Thread(target=_poll, daemon=True)
|
||||
poll_thread.start()
|
||||
# Finish as soon as the model is pushed, rather than waiting out the platform's
|
||||
# post-run finalization before the job stage flips to COMPLETED.
|
||||
success_marker = f"Model pushed to https://huggingface.co/{repo_id}"
|
||||
log_thread = threading.Thread(
|
||||
target=_tail_logs, args=(job_id, done, success_marker, pushed_ok), daemon=True
|
||||
)
|
||||
log_thread.start()
|
||||
|
||||
def _detach(sig, frame):
|
||||
detached.set()
|
||||
done.set()
|
||||
print("\nDetached. Job is still running.")
|
||||
print(f" Monitor: hf jobs logs {job_id}")
|
||||
print(f" Cancel: hf jobs cancel {job_id}")
|
||||
|
||||
original_sigint = signal.getsignal(signal.SIGINT)
|
||||
signal.signal(signal.SIGINT, _detach)
|
||||
try:
|
||||
# Timeout-based join so SIGINT is delivered to the main thread promptly.
|
||||
while poll_thread.is_alive():
|
||||
poll_thread.join(timeout=0.5)
|
||||
log_thread.join(timeout=5)
|
||||
finally:
|
||||
signal.signal(signal.SIGINT, original_sigint)
|
||||
|
||||
if detached.is_set():
|
||||
return
|
||||
|
||||
if pushed_ok.is_set():
|
||||
print(f"\nTraining complete — model pushed to https://huggingface.co/{repo_id}")
|
||||
return
|
||||
|
||||
stage = stage_holder.get("stage")
|
||||
if stage != "COMPLETED":
|
||||
message = stage_holder.get("message")
|
||||
detail = f" ({message})" if message else ""
|
||||
raise RuntimeError(
|
||||
f"Job {job_id} ended with stage={stage}{detail}. Check logs: hf jobs logs {job_id}"
|
||||
)
|
||||
@@ -20,6 +20,7 @@ from .optimizers import (
|
||||
SGDConfig as SGDConfig,
|
||||
XVLAAdamWConfig as XVLAAdamWConfig,
|
||||
load_optimizer_state,
|
||||
load_optimizer_state_dict,
|
||||
save_optimizer_state,
|
||||
)
|
||||
from .schedulers import (
|
||||
@@ -50,6 +51,7 @@ __all__ = [
|
||||
"VQBeTSchedulerConfig",
|
||||
# State management
|
||||
"load_optimizer_state",
|
||||
"load_optimizer_state_dict",
|
||||
"load_scheduler_state",
|
||||
"save_optimizer_state",
|
||||
"save_scheduler_state",
|
||||
|
||||
@@ -27,7 +27,7 @@ from lerobot.utils.constants import (
|
||||
OPTIMIZER_PARAM_GROUPS,
|
||||
OPTIMIZER_STATE,
|
||||
)
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
|
||||
from lerobot.utils.io_utils import deserialize_json_into_object, load_json, write_json
|
||||
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
||||
|
||||
# Type alias for parameters accepted by optimizer build() methods.
|
||||
@@ -281,28 +281,37 @@ class MultiAdamConfig(OptimizerConfig):
|
||||
|
||||
|
||||
def save_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer],
|
||||
save_dir: Path,
|
||||
optim_state_dict: dict | None = None,
|
||||
) -> None:
|
||||
"""Save optimizer state to disk.
|
||||
|
||||
Args:
|
||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||
save_dir: Directory to save the optimizer state.
|
||||
optim_state_dict: Pre-gathered optimizer state dict (for FSDP, where the sharded state must
|
||||
be gathered across ranks first). If provided, it is saved directly instead of calling
|
||||
``optimizer.state_dict()``. Only supported for a single optimizer. Defaults to None.
|
||||
"""
|
||||
if isinstance(optimizer, dict):
|
||||
# Handle dictionary of optimizers
|
||||
if optim_state_dict is not None:
|
||||
raise ValueError("optim_state_dict is not supported for a dict of optimizers")
|
||||
for name, opt in optimizer.items():
|
||||
optimizer_dir = save_dir / name
|
||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||
_save_single_optimizer_state(opt, optimizer_dir)
|
||||
else:
|
||||
# Handle single optimizer
|
||||
_save_single_optimizer_state(optimizer, save_dir)
|
||||
_save_single_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
|
||||
|
||||
|
||||
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||
def _save_single_optimizer_state(
|
||||
optimizer: torch.optim.Optimizer, save_dir: Path, optim_state_dict: dict | None = None
|
||||
) -> None:
|
||||
"""Save a single optimizer's state to disk."""
|
||||
state = optimizer.state_dict()
|
||||
state = dict(optim_state_dict) if optim_state_dict is not None else optimizer.state_dict()
|
||||
param_groups = state.pop("param_groups")
|
||||
flat_state = flatten_dict(state)
|
||||
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
||||
@@ -356,3 +365,19 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
|
||||
|
||||
optimizer.load_state_dict(loaded_state_dict)
|
||||
return optimizer
|
||||
|
||||
|
||||
def load_optimizer_state_dict(save_dir: Path) -> dict:
|
||||
"""Read a saved optimizer state dict (safetensors + json) back into a plain dict.
|
||||
|
||||
Unlike `load_optimizer_state`, this does not load into an optimizer and preserves the original
|
||||
``state`` keys verbatim (e.g. FSDP parameter FQNs, which are not integer-castable). It is used by
|
||||
the FSDP resume path, where the full state must be resharded via `FSDP.optim_state_dict_to_load`
|
||||
before being loaded into the (sharded) optimizer.
|
||||
"""
|
||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
||||
state = unflatten_dict(flat_state)
|
||||
return {
|
||||
"state": state.get("state", {}),
|
||||
"param_groups": load_json(save_dir / OPTIMIZER_PARAM_GROUPS),
|
||||
}
|
||||
|
||||
@@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy):
|
||||
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||
|
||||
loss_dict = {"l1_loss": l1_loss.item()}
|
||||
if self.config.use_vae:
|
||||
if self.config.use_vae and log_sigma_x2_hat is not None:
|
||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||
# each dimension independently, we sum over the latent dimension to get the total
|
||||
# KL-divergence per batch element, then take the mean over the batch.
|
||||
|
||||
@@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
# stack n latest observations from the queue
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||
"""Predict a chunk of actions given environment observations.
|
||||
|
||||
Supports two modes:
|
||||
- Online (queues populated via select_action): stacks observations from internal queues.
|
||||
- Offline (empty queues, e.g. dataloader batch): uses the batch directly.
|
||||
"""
|
||||
queues_populated = any(len(q) > 0 for q in self._queues.values())
|
||||
if queues_populated:
|
||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||
else:
|
||||
batch = dict(batch)
|
||||
if self.config.image_features:
|
||||
for key in self.config.image_features:
|
||||
if batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@@ -252,6 +252,7 @@ class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
def make_pre_post_processors(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
pretrained_revision: str | None = None,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
@@ -309,6 +310,7 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
@@ -318,6 +320,7 @@ def make_pre_post_processors(
|
||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
revision=pretrained_revision,
|
||||
)
|
||||
_reconnect_relative_absolute_steps(preprocessor, postprocessor)
|
||||
return preprocessor, postprocessor
|
||||
@@ -557,6 +560,7 @@ def make_policy(
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
# hyperparameters that we want to vary).
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
policy = policy_cls.from_pretrained(**kwargs)
|
||||
elif cfg.pretrained_path and cfg.use_peft:
|
||||
# Load a pretrained PEFT model on top of the policy. The pretrained path points to the folder/repo
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
|
||||
|
||||
import packaging
|
||||
import safetensors
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||
@@ -129,10 +129,43 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
if not getattr(cls, "name", None):
|
||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: str | Path,
|
||||
*,
|
||||
state_dict: dict[str, Tensor] | None = None,
|
||||
repo_id: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
card_kwargs: dict | None = None,
|
||||
**push_to_hub_kwargs,
|
||||
) -> str | None:
|
||||
"""Save the policy to a directory (and optionally push to the Hub).
|
||||
|
||||
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
|
||||
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
|
||||
return sharded tensors, so the caller gathers the full state dict via a cross-rank
|
||||
collective and passes it here for `_save_pretrained` to write directly.
|
||||
"""
|
||||
save_directory = Path(save_directory)
|
||||
save_directory.mkdir(parents=True, exist_ok=True)
|
||||
self._save_pretrained(save_directory, state_dict=state_dict)
|
||||
if push_to_hub:
|
||||
if repo_id is None:
|
||||
repo_id = save_directory.name
|
||||
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
|
||||
self.config._save_pretrained(save_directory)
|
||||
model_to_save = self.module if hasattr(self, "module") else self
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
if state_dict is None:
|
||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||
return
|
||||
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
|
||||
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
|
||||
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
|
||||
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
|
||||
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
@@ -270,6 +303,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
self,
|
||||
cfg: TrainPipelineConfig,
|
||||
peft_model=None,
|
||||
state_dict: dict[str, Tensor] | None = None,
|
||||
):
|
||||
api = HfApi()
|
||||
repo_id = api.create_repo(
|
||||
@@ -287,7 +321,8 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
peft_model.save_pretrained(saved_path)
|
||||
self.config.save_pretrained(saved_path)
|
||||
else:
|
||||
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||
# Calls _save_pretrained and stores model tensors
|
||||
self.save_pretrained(saved_path, state_dict=state_dict)
|
||||
|
||||
card = self.generate_model_card(
|
||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||
|
||||
@@ -124,6 +124,7 @@ def make_reward_model(cfg: RewardModelConfig, **kwargs) -> PreTrainedRewardModel
|
||||
|
||||
if cfg.pretrained_path:
|
||||
kwargs["pretrained_name_or_path"] = cfg.pretrained_path
|
||||
kwargs["revision"] = cfg.pretrained_revision
|
||||
reward_model = reward_cls.from_pretrained(**kwargs)
|
||||
else:
|
||||
reward_model = reward_cls(**kwargs)
|
||||
|
||||
@@ -34,11 +34,14 @@ from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
gather_fsdp_state_dicts,
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_fsdp_optimizer_state,
|
||||
load_training_batch_size,
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
push_checkpoint_to_hub,
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
@@ -185,10 +188,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||
"""
|
||||
if cfg.job.is_remote:
|
||||
from lerobot.jobs import submit_to_hf
|
||||
|
||||
return submit_to_hf(cfg)
|
||||
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("accelerate", extra="training")
|
||||
from accelerate import Accelerator
|
||||
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
|
||||
|
||||
cfg.validate()
|
||||
|
||||
@@ -197,8 +206,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
# We set find_unused_parameters=True to handle models with conditional computation
|
||||
if accelerator is None:
|
||||
from accelerate.utils import DistributedDataParallelKwargs
|
||||
|
||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||
@@ -345,6 +352,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=processor_pretrained_path,
|
||||
pretrained_revision=getattr(cfg.policy, "pretrained_revision", None),
|
||||
**processor_kwargs,
|
||||
)
|
||||
|
||||
@@ -370,7 +378,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
if cfg.resume:
|
||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||
# Under FSDP the optimizer state is sharded and must be loaded after `accelerator.prepare()`
|
||||
# (see load_fsdp_optimizer_state below), so skip the optimizer here and load it then.
|
||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||
step, optimizer, lr_scheduler = load_training_state(
|
||||
cfg.checkpoint_path, optimizer, lr_scheduler, load_optimizer=not is_fsdp
|
||||
)
|
||||
|
||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||
@@ -460,6 +473,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||
policy, optimizer, dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# FSDP optimizer state is sharded across ranks, so it can only be loaded once the optimizer and
|
||||
# model are FSDP-wrapped (i.e. after `prepare`). Collective: every rank must participate.
|
||||
if cfg.resume and accelerator.distributed_type == DistributedType.FSDP:
|
||||
load_fsdp_optimizer_state(policy, optimizer, cfg.checkpoint_path)
|
||||
|
||||
dl_iter = cycle(dataloader)
|
||||
|
||||
policy.train()
|
||||
@@ -558,6 +577,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
train_tracker.reset_averages()
|
||||
|
||||
if cfg.save_checkpoint and is_saving_step:
|
||||
# Under FSDP, gathering the full model + optimizer state dicts is a cross-rank collective,
|
||||
# so all ranks must participate; rank 0 then writes the materialized dicts. For DDP /
|
||||
# single-GPU the state dicts are saved the normal way inside save_checkpoint.
|
||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||
if is_fsdp:
|
||||
model_state_dict, optim_state_dict = gather_fsdp_state_dicts(policy, optimizer)
|
||||
else:
|
||||
model_state_dict, optim_state_dict = None, None
|
||||
if is_main_process:
|
||||
logging.info(f"Checkpoint policy after step {step}")
|
||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||
@@ -572,8 +599,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
postprocessor=postprocessor,
|
||||
num_processes=accelerator.num_processes,
|
||||
batch_size=cfg.batch_size,
|
||||
model_state_dict=model_state_dict,
|
||||
optim_state_dict=optim_state_dict,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
if cfg.save_checkpoint_to_hub:
|
||||
push_checkpoint_to_hub(
|
||||
checkpoint_dir,
|
||||
cfg.policy.repo_id,
|
||||
private=cfg.policy.private,
|
||||
)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -634,6 +669,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
||||
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
|
||||
if is_main_process:
|
||||
logging.info("End of training")
|
||||
|
||||
@@ -643,7 +680,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||
else:
|
||||
unwrapped_model.push_model_to_hub(cfg)
|
||||
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
|
||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||
|
||||
@@ -652,8 +689,29 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
def _remote_target_in_argv() -> bool:
|
||||
"""True when the CLI requests a remote HF Jobs run (--job.target=<non-local>)."""
|
||||
import sys
|
||||
|
||||
from lerobot.configs.default import JobConfig
|
||||
|
||||
target = None
|
||||
args = sys.argv[1:]
|
||||
for i, tok in enumerate(args):
|
||||
if tok == "--job.target" and i + 1 < len(args):
|
||||
target = args[i + 1]
|
||||
elif tok.startswith("--job.target="):
|
||||
target = tok.split("=", 1)[1]
|
||||
return JobConfig.is_remote_target(target)
|
||||
|
||||
|
||||
def main():
|
||||
register_third_party_plugins()
|
||||
if _remote_target_in_argv():
|
||||
# The policy device is resolved on the remote pod, not here, so silence the
|
||||
# client-side "Device '...' is not available" warning PreTrainedConfig emits
|
||||
# while parsing the config (it fires before train() can dispatch remotely).
|
||||
logging.getLogger("lerobot.configs.policies").setLevel(logging.ERROR)
|
||||
train()
|
||||
|
||||
|
||||
|
||||
@@ -216,9 +216,15 @@ def register_third_party_plugins() -> None:
|
||||
|
||||
This function uses `importlib.metadata` to find packages installed in the environment
|
||||
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
|
||||
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
|
||||
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
|
||||
"""
|
||||
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
|
||||
prefixes = (
|
||||
"lerobot_robot_",
|
||||
"lerobot_camera_",
|
||||
"lerobot_teleoperator_",
|
||||
"lerobot_policy_",
|
||||
"lerobot_env_",
|
||||
)
|
||||
imported: list[str] = []
|
||||
failed: list[str] = []
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ from lerobot.robots import make_robot_from_config
|
||||
from lerobot.transforms import ImageTransforms, ImageTransformsConfig
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
|
||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_MOTOR_FEATURES, DUMMY_REPO_ID
|
||||
from tests.mocks.mock_robot import MockRobotConfig
|
||||
from tests.utils import require_x86_64_kernel
|
||||
|
||||
@@ -133,6 +133,21 @@ def test_dataset_feature_with_forward_slash_raises_error():
|
||||
)
|
||||
|
||||
|
||||
def test_create_does_not_mutate_input_features(tmp_path, empty_lerobot_dataset_factory):
|
||||
# ``create`` must deep-copy features so a dataset built from another's features stays independent.
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds1", features=DUMMY_MOTOR_FEATURES, use_videos=False
|
||||
)
|
||||
dataset_copy = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "ds2", features=dataset.meta.features, use_videos=False
|
||||
)
|
||||
|
||||
original_shape = dataset.meta.info.features["state"]["shape"]
|
||||
dataset_copy.meta.info.features["state"]["shape"] = (999,)
|
||||
|
||||
assert dataset.meta.info.features["state"]["shape"] == original_shape
|
||||
|
||||
|
||||
def test_add_frame_missing_task(tmp_path, empty_lerobot_dataset_factory):
|
||||
features = {"state": {"dtype": "float32", "shape": (1,), "names": None}}
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Importing concrete policy configs registers their draccus `--policy.type`
|
||||
# choices (e.g. "act") so tests can parse them.
|
||||
from lerobot.policies.act.configuration_act import ACTConfig # noqa: F401
|
||||
@@ -0,0 +1,78 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from huggingface_hub.errors import RepositoryNotFoundError
|
||||
|
||||
from lerobot.jobs.dataset import ensure_dataset_available
|
||||
|
||||
|
||||
def _repo_not_found() -> RepositoryNotFoundError:
|
||||
req = httpx.Request("GET", "https://huggingface.co/datasets/test")
|
||||
resp = httpx.Response(404, request=req)
|
||||
return RepositoryNotFoundError("nope", response=resp)
|
||||
|
||||
|
||||
def _api_with_dataset(exists: bool):
|
||||
api = MagicMock()
|
||||
if exists:
|
||||
api.dataset_info.return_value = object()
|
||||
else:
|
||||
api.dataset_info.side_effect = _repo_not_found()
|
||||
return api
|
||||
|
||||
|
||||
def _make_local_cache(tmp_path, repo_id: str) -> None:
|
||||
"""Create the minimal local-cache layout that ensure_dataset_available checks."""
|
||||
info = tmp_path / repo_id / "meta" / "info.json"
|
||||
info.parent.mkdir(parents=True)
|
||||
info.write_text("{}")
|
||||
|
||||
|
||||
# Branch 1: dataset already on Hub → no push, no error (pod downloads by repo_id).
|
||||
def test_dataset_already_on_hub_is_noop():
|
||||
api = _api_with_dataset(True)
|
||||
assert ensure_dataset_available("user/ds", api=api) is None
|
||||
api.dataset_info.assert_called_once_with("user/ds")
|
||||
|
||||
|
||||
# Branch 2: not on Hub but present locally → always push privately.
|
||||
def test_dataset_local_only_uploads_privately(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
|
||||
_make_local_cache(tmp_path, "user/ds")
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
mock_ds_cls = MagicMock()
|
||||
fake_datasets_module = MagicMock()
|
||||
fake_datasets_module.LeRobotDataset = mock_ds_cls
|
||||
monkeypatch.setitem(sys.modules, "lerobot.datasets", fake_datasets_module)
|
||||
|
||||
assert ensure_dataset_available("user/ds", api=api, tags=["lerobot", "lelab"]) is None
|
||||
|
||||
mock_ds_cls.assert_called_once_with("user/ds")
|
||||
mock_ds_cls.return_value.push_to_hub.assert_called_once_with(private=True, tags=["lerobot", "lelab"])
|
||||
|
||||
|
||||
# Branch 3: not on Hub, NOT in local cache → RuntimeError "neither".
|
||||
def test_dataset_neither_on_hub_nor_local_raises(tmp_path, monkeypatch):
|
||||
monkeypatch.setenv("HF_LEROBOT_HOME", str(tmp_path))
|
||||
# tmp_path is empty — no local cache.
|
||||
|
||||
api = _api_with_dataset(False)
|
||||
with pytest.raises(RuntimeError, match="neither"):
|
||||
ensure_dataset_available("user/ds", api=api)
|
||||
@@ -0,0 +1,426 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime as dt
|
||||
import json
|
||||
import threading
|
||||
from types import SimpleNamespace
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.jobs.hf import (
|
||||
_poll_until_done,
|
||||
build_remote_config_file,
|
||||
build_repo_id,
|
||||
resolve_job_tags,
|
||||
resolve_wandb_api_key,
|
||||
submit_to_hf,
|
||||
)
|
||||
|
||||
|
||||
def test_resolve_job_tags_always_includes_lerobot_and_dedups():
|
||||
assert resolve_job_tags(None) == ["lerobot"]
|
||||
assert resolve_job_tags([]) == ["lerobot"]
|
||||
assert resolve_job_tags(["lelab"]) == ["lerobot", "lelab"]
|
||||
# lerobot isn't duplicated if passed explicitly; order is stable.
|
||||
assert resolve_job_tags(["lelab", "lerobot", "lelab"]) == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def _fake_inspect(stage_value):
|
||||
return lambda job_id: SimpleNamespace(status=SimpleNamespace(stage=SimpleNamespace(value=stage_value)))
|
||||
|
||||
|
||||
def test_poll_until_done_returns_terminal_stage(monkeypatch):
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("COMPLETED"))
|
||||
done = threading.Event()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) == "COMPLETED"
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_poll_until_done_exits_when_done_already_set(monkeypatch):
|
||||
# Non-terminal forever; with done pre-set the loop must not block and returns None.
|
||||
monkeypatch.setattr("huggingface_hub.inspect_job", _fake_inspect("RUNNING"))
|
||||
done = threading.Event()
|
||||
done.set()
|
||||
assert _poll_until_done("j", done, poll_interval=0.01) is None
|
||||
|
||||
|
||||
def test_poll_until_done_gives_up_after_repeated_failures(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job", lambda job_id: (_ for _ in ()).throw(RuntimeError("boom"))
|
||||
)
|
||||
done = threading.Event()
|
||||
result = _poll_until_done("j", done, poll_interval=0.001, max_failures=3)
|
||||
assert result is None
|
||||
assert done.is_set()
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_env(monkeypatch):
|
||||
monkeypatch.setenv("WANDB_API_KEY", "abc123")
|
||||
assert resolve_wandb_api_key() == "abc123"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_missing(monkeypatch, tmp_path):
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
monkeypatch.setenv("HOME", str(tmp_path)) # no ~/.netrc here
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: (_ for _ in ()).throw(FileNotFoundError()))
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_resolve_wandb_key_from_netrc(monkeypatch):
|
||||
# No env var → fall back to the wandb credentials in ~/.netrc.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
assert host == "api.wandb.ai"
|
||||
return ("login", "account", "netrc-secret")
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() == "netrc-secret"
|
||||
|
||||
|
||||
def test_resolve_wandb_key_netrc_without_wandb_entry(monkeypatch):
|
||||
# ~/.netrc exists but has no api.wandb.ai entry → None.
|
||||
monkeypatch.delenv("WANDB_API_KEY", raising=False)
|
||||
|
||||
class _FakeNetrc:
|
||||
def authenticators(self, host):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr("netrc.netrc", lambda *a, **k: _FakeNetrc())
|
||||
assert resolve_wandb_api_key() is None
|
||||
|
||||
|
||||
def test_build_repo_id_sanitizes_and_timestamps():
|
||||
now = dt.datetime(2026, 6, 19, 10, 22, 3)
|
||||
assert build_repo_id("alice", "act", now) == "alice/act_2026-06-19_10-22-03"
|
||||
# Runs of illegal characters collapse to a single dash; edges are trimmed.
|
||||
assert build_repo_id("alice", "my cool/run!!", now) == "alice/my-cool-run_2026-06-19_10-22-03"
|
||||
# A name with nothing usable falls back to "train".
|
||||
assert build_repo_id("alice", "///", now) == "alice/train_2026-06-19_10-22-03"
|
||||
|
||||
|
||||
def _minimal_cfg():
|
||||
return draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
|
||||
|
||||
def test_build_remote_config_applies_overrides(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
out = build_remote_config_file(cfg, "u/run", dest)
|
||||
assert out == dest
|
||||
data = json.loads(dest.read_text())
|
||||
# `job` is client-only orchestration and must be stripped for the pod.
|
||||
assert "job" not in data
|
||||
# save_checkpoint_to_hub defaults off → omitted so older images accept the config.
|
||||
assert "save_checkpoint_to_hub" not in data
|
||||
assert data["policy"]["push_to_hub"] is True
|
||||
assert data["policy"]["repo_id"] == "u/run"
|
||||
assert data["policy"]["device"] is None # pod auto-detects its GPU
|
||||
assert data["dataset"]["root"] is None # pod resolves the dataset by repo_id
|
||||
# the caller's cfg must be left untouched (function works on a deep copy)
|
||||
assert cfg.job.target == "a10g-small"
|
||||
assert cfg.save_checkpoint_to_hub is False
|
||||
|
||||
|
||||
def test_build_remote_config_includes_checkpoint_flag_when_enabled(tmp_path):
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--save_checkpoint_to_hub",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest)
|
||||
data = json.loads(dest.read_text())
|
||||
# explicitly enabled → kept in the config (requires a matching trainer image).
|
||||
assert data["save_checkpoint_to_hub"] is True
|
||||
assert "job" not in data
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_into_policy(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# tags propagate to the model the pod pushes.
|
||||
assert data["policy"]["tags"] == ["lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_build_remote_config_merges_tags_without_duplicating(tmp_path):
|
||||
cfg = _minimal_cfg()
|
||||
cfg.policy.tags = ["existing", "lerobot"]
|
||||
dest = tmp_path / "train_config.json"
|
||||
build_remote_config_file(cfg, "u/run", dest, tags=["lerobot", "lelab"])
|
||||
data = json.loads(dest.read_text())
|
||||
# pre-existing policy tags are kept; only genuinely-new tags are appended (no dup "lerobot").
|
||||
assert data["policy"]["tags"] == ["existing", "lerobot", "lelab"]
|
||||
|
||||
|
||||
def test_submit_requires_login(monkeypatch):
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: None)
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match="hf auth login"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_passes_validation_and_submits(monkeypatch):
|
||||
"""Regression: repo_id must be set BEFORE cfg.validate() or validation raises."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
# Patch get_token
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
# Patch HfApi so whoami returns alice
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
|
||||
# ensure_dataset_available returns None; patch it out so no Hub access happens
|
||||
# (imported inside submit_to_hf via `from lerobot.jobs.dataset import ensure_dataset_available`).
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
|
||||
# Patch _stage_config_on_hub to skip network
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub",
|
||||
lambda cfg, repo_id, token, tags=None: repo_id,
|
||||
)
|
||||
|
||||
# Patch run_job to return a fake job
|
||||
fake_job = MagicMock()
|
||||
fake_job.id = "job-123"
|
||||
run_job_calls = []
|
||||
|
||||
def fake_run_job(**kwargs):
|
||||
run_job_calls.append(kwargs)
|
||||
return fake_job
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", fake_run_job)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--job.detach",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
|
||||
# Must NOT raise (pre-fix this raised ValueError about missing repo_id)
|
||||
submit_to_hf(cfg)
|
||||
|
||||
assert len(run_job_calls) == 1, "run_job should have been called exactly once"
|
||||
assert cfg.policy.repo_id is not None
|
||||
assert cfg.policy.repo_id.startswith("alice/")
|
||||
call = run_job_calls[0]
|
||||
# The pod runs `lerobot-train --config_path=<staged repo>` on the requested flavor/image.
|
||||
assert call["command"][0] == "lerobot-train"
|
||||
assert call["command"][1].startswith("--config_path=")
|
||||
assert call["flavor"] == "a10g-small"
|
||||
assert call["image"] == "huggingface/lerobot-gpu:latest"
|
||||
# The Hub token is forwarded so the pod can pull the (possibly private) dataset.
|
||||
assert call["secrets"]["HF_TOKEN"] == "tok"
|
||||
# Every job carries the lerobot tag as a queryable label.
|
||||
assert call["labels"].get("lerobot") == "true"
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_when_job_completes(monkeypatch):
|
||||
"""Non-detach path must RETURN (not hang) once the job reaches a terminal stage."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job is already COMPLETED on the first poll.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="COMPLETED"), message=None)
|
||||
),
|
||||
)
|
||||
# Log stream ends immediately.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Runs in the pytest main thread (signal handler install requires it); the
|
||||
# @timeout marker fails the test instead of hanging if it regresses.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_returns_on_model_pushed_marker(monkeypatch):
|
||||
"""Finish when the model-pushed log appears, even if the job stage never flips."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job stays RUNNING forever — only the log marker can end the command.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="RUNNING"), message=None)
|
||||
),
|
||||
)
|
||||
pushed_line = "INFO Model pushed to https://huggingface.co/alice/myrun"
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter([pushed_line]))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--policy.repo_id",
|
||||
"alice/myrun",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
],
|
||||
)
|
||||
# Must return via the model-pushed marker despite the perpetual RUNNING stage.
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
def test_submit_raises_when_wandb_enabled_without_key(monkeypatch):
|
||||
"""wandb.enable with no key reachable anywhere fails fast, before submitting."""
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.hf.resolve_wandb_api_key", lambda: None)
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--job.target",
|
||||
"a10g-small",
|
||||
"--wandb.enable",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="WANDB_API_KEY"):
|
||||
submit_to_hf(cfg)
|
||||
|
||||
|
||||
@pytest.mark.timeout(15)
|
||||
def test_submit_raises_when_job_ends_in_error(monkeypatch):
|
||||
"""A terminal non-COMPLETED stage with no model-pushed marker must raise with the status."""
|
||||
from types import SimpleNamespace
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
monkeypatch.setattr("lerobot.jobs.hf.get_token", lambda: "tok")
|
||||
|
||||
class FakeHfApi:
|
||||
def __init__(self, token=None):
|
||||
pass
|
||||
|
||||
def whoami(self, token=None):
|
||||
return {"name": "alice"}
|
||||
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", FakeHfApi)
|
||||
monkeypatch.setattr("lerobot.jobs.dataset.ensure_dataset_available", lambda *a, **kw: None)
|
||||
monkeypatch.setattr(
|
||||
"lerobot.jobs.hf._stage_config_on_hub", lambda cfg, repo_id, token, tags=None: repo_id
|
||||
)
|
||||
monkeypatch.setattr(huggingface_hub, "run_job", lambda **kw: SimpleNamespace(id="job-1", url="http://x"))
|
||||
# Job fails: a terminal ERROR stage carrying the platform's status message.
|
||||
monkeypatch.setattr(
|
||||
"huggingface_hub.inspect_job",
|
||||
lambda job_id: SimpleNamespace(
|
||||
status=SimpleNamespace(stage=SimpleNamespace(value="ERROR"), message="Job timeout")
|
||||
),
|
||||
)
|
||||
# Logs end without the model-pushed marker.
|
||||
monkeypatch.setattr("huggingface_hub.fetch_job_logs", lambda job_id, follow=True: iter(()))
|
||||
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
with pytest.raises(RuntimeError, match=r"stage=ERROR \(Job timeout\)"):
|
||||
submit_to_hf(cfg)
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.default import JobConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
|
||||
|
||||
def test_jobconfig_defaults_are_local():
|
||||
cfg = JobConfig()
|
||||
assert cfg.target is None
|
||||
assert cfg.is_remote is False
|
||||
assert cfg.image == "huggingface/lerobot-gpu:latest"
|
||||
assert cfg.timeout is None
|
||||
assert cfg.detach is False
|
||||
|
||||
|
||||
def test_jobconfig_local_string_is_not_remote():
|
||||
assert JobConfig(target="local").is_remote is False
|
||||
|
||||
|
||||
def test_jobconfig_flavor_is_remote():
|
||||
assert JobConfig(target="a10g-small").is_remote is True
|
||||
|
||||
|
||||
def test_train_config_parses_job_target():
|
||||
parsed = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
assert parsed.job.target == "a10g-small"
|
||||
assert parsed.job.is_remote is True
|
||||
assert parsed.save_checkpoint_to_hub is False
|
||||
|
||||
|
||||
def test_save_checkpoint_to_hub_requires_repo_id():
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=[
|
||||
"--dataset.repo_id",
|
||||
"u/d",
|
||||
"--policy.type",
|
||||
"act",
|
||||
"--policy.push_to_hub",
|
||||
"false",
|
||||
"--save_checkpoint_to_hub",
|
||||
"true",
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValueError, match="requires --policy.repo_id"):
|
||||
cfg.validate()
|
||||
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import (
|
||||
MultiAdamConfig,
|
||||
SGDConfig,
|
||||
load_optimizer_state,
|
||||
load_optimizer_state_dict,
|
||||
save_optimizer_state,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
@@ -65,6 +66,44 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
|
||||
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
||||
|
||||
|
||||
def test_save_and_load_fsdp_optimizer_state_dict_roundtrip(tmp_path):
|
||||
"""The FSDP full optimizer state dict is keyed by parameter FQNs (dotted strings), not the
|
||||
integer indices of the single-GPU path. Verify it survives the safetensors save -> read
|
||||
round-trip used by the FSDP save/resume path (save_optimizer_state(optim_state_dict=...) then
|
||||
load_optimizer_state_dict), which the flatten/unflatten "/" separator must not corrupt."""
|
||||
full_osd = {
|
||||
"state": {
|
||||
"model.layers.0.weight": {
|
||||
"step": torch.tensor(3.0),
|
||||
"exp_avg": torch.randn(4, 4),
|
||||
"exp_avg_sq": torch.randn(4, 4),
|
||||
},
|
||||
"model.layers.0.bias": {
|
||||
"step": torch.tensor(3.0),
|
||||
"exp_avg": torch.randn(4),
|
||||
"exp_avg_sq": torch.randn(4),
|
||||
},
|
||||
},
|
||||
"param_groups": [
|
||||
{"lr": 1e-4, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.0, "params": [0, 1]}
|
||||
],
|
||||
}
|
||||
|
||||
save_optimizer_state(
|
||||
torch.optim.Adam([torch.nn.Parameter(torch.randn(1))]), tmp_path, optim_state_dict=full_osd
|
||||
)
|
||||
assert (tmp_path / OPTIMIZER_STATE).is_file()
|
||||
assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
|
||||
|
||||
loaded = load_optimizer_state_dict(tmp_path)
|
||||
# FQN keys must be preserved verbatim (not int-cast, not split on their dots).
|
||||
assert set(loaded["state"].keys()) == set(full_osd["state"].keys())
|
||||
for fqn, sub in full_osd["state"].items():
|
||||
for k, v in sub.items():
|
||||
torch.testing.assert_close(loaded["state"][fqn][k], v)
|
||||
assert loaded["param_groups"] == full_osd["param_groups"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_params_dict():
|
||||
return {
|
||||
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from packaging import version
|
||||
from safetensors.torch import load_file
|
||||
|
||||
@@ -300,6 +301,29 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
|
||||
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
|
||||
policy_cls = get_policy_class("act")
|
||||
policy_cfg = make_policy_config("act")
|
||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
policy_cfg.input_features = {
|
||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
||||
}
|
||||
policy = policy_cls(policy_cfg)
|
||||
policy.to(policy_cfg.device)
|
||||
|
||||
save_dir = tmp_path / "fsdp_state_dict"
|
||||
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
|
||||
|
||||
# A single, unsharded safetensors file (no sharded set + index).
|
||||
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
|
||||
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
|
||||
|
||||
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("multikey", [True, False])
|
||||
def test_multikey_construction(multikey: bool):
|
||||
"""
|
||||
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
|
||||
import draccus
|
||||
import pytest
|
||||
|
||||
# Importing lerobot_train eagerly pulls in lerobot.datasets, which needs the
|
||||
# `dataset` extra. The base CI tier runs without it, so skip the whole module there.
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig # noqa: E402
|
||||
from lerobot.policies.act.configuration_act import (
|
||||
ACTConfig, # noqa: E402, F401 (registers --policy.type act)
|
||||
)
|
||||
from lerobot.scripts.lerobot_train import _remote_target_in_argv, train # noqa: E402
|
||||
|
||||
|
||||
def _set_argv(monkeypatch, *args):
|
||||
monkeypatch.setattr(sys, "argv", ["lerobot-train", *args])
|
||||
|
||||
|
||||
def test_remote_target_detected_space_separated(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act", "--job.target", "a10g-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_remote_target_detected_equals(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target=t4-small")
|
||||
assert _remote_target_in_argv() is True
|
||||
|
||||
|
||||
def test_local_string_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--job.target", "local")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_no_target_is_not_remote(monkeypatch):
|
||||
_set_argv(monkeypatch, "--policy.type", "act")
|
||||
assert _remote_target_in_argv() is False
|
||||
|
||||
|
||||
def test_train_dispatches_to_submit_when_remote(monkeypatch):
|
||||
"""A remote --job.target short-circuits train() to the HF Jobs submitter."""
|
||||
import lerobot.jobs
|
||||
|
||||
captured = []
|
||||
monkeypatch.setattr(lerobot.jobs, "submit_to_hf", lambda cfg: captured.append(cfg) or "submitted")
|
||||
cfg = draccus.parse(
|
||||
TrainPipelineConfig,
|
||||
args=["--dataset.repo_id", "u/d", "--policy.type", "act", "--job.target", "a10g-small"],
|
||||
)
|
||||
# Returns the submitter's result and never enters the local training path.
|
||||
assert train(cfg) == "submitted"
|
||||
assert captured == [cfg]
|
||||
@@ -58,7 +58,46 @@ def download_dataset(repo_id, episodes):
|
||||
print(f"Dataset {repo_id} downloaded successfully")
|
||||
|
||||
|
||||
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
def _write_multi_gpu_config(f, num_processes):
|
||||
f.write("compute_environment: LOCAL_MACHINE\n")
|
||||
f.write("distributed_type: MULTI_GPU\n")
|
||||
f.write("mixed_precision: 'no'\n")
|
||||
f.write(f"num_processes: {num_processes}\n")
|
||||
f.write("use_cpu: false\n")
|
||||
f.write("gpu_ids: all\n")
|
||||
f.write("downcast_bf16: 'no'\n")
|
||||
f.write("machine_rank: 0\n")
|
||||
f.write("main_training_function: main\n")
|
||||
f.write("num_machines: 1\n")
|
||||
f.write("rdzv_backend: static\n")
|
||||
f.write("same_network: true\n")
|
||||
|
||||
|
||||
def _write_fsdp_config(f, num_processes):
|
||||
# FSDP1 with FULL_SHARD (ZeRO-3-equivalent) and FULL_STATE_DICT, matching
|
||||
# docs/source/multi_gpu_training.mdx. ACT's repeated transformer blocks are the wrap units;
|
||||
# fsdp_use_orig_params is required because LeRobot builds the optimizer before prepare().
|
||||
f.write("compute_environment: LOCAL_MACHINE\n")
|
||||
f.write("distributed_type: FSDP\n")
|
||||
f.write("mixed_precision: 'no'\n")
|
||||
f.write(f"num_processes: {num_processes}\n")
|
||||
f.write("use_cpu: false\n")
|
||||
f.write("gpu_ids: all\n")
|
||||
f.write("machine_rank: 0\n")
|
||||
f.write("main_training_function: main\n")
|
||||
f.write("num_machines: 1\n")
|
||||
f.write("rdzv_backend: static\n")
|
||||
f.write("same_network: true\n")
|
||||
f.write("fsdp_config:\n")
|
||||
f.write(" fsdp_version: 1\n")
|
||||
f.write(" fsdp_sharding_strategy: FULL_SHARD\n")
|
||||
f.write(" fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n")
|
||||
f.write(" fsdp_transformer_layer_cls_to_wrap: ACTEncoderLayer,ACTDecoderLayer\n")
|
||||
f.write(" fsdp_use_orig_params: true\n")
|
||||
f.write(" fsdp_state_dict_type: FULL_STATE_DICT\n")
|
||||
|
||||
|
||||
def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distributed_type="MULTI_GPU"):
|
||||
"""
|
||||
Helper function to run training with accelerate launch.
|
||||
|
||||
@@ -66,6 +105,7 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
config_args: List of config arguments to pass to lerobot_train.py
|
||||
num_processes: Number of processes (GPUs) to use
|
||||
temp_dir: Temporary directory for outputs
|
||||
distributed_type: "MULTI_GPU" (DDP) or "FSDP" — selects the generated accelerate config.
|
||||
|
||||
Returns:
|
||||
subprocess.CompletedProcess result
|
||||
@@ -75,18 +115,10 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||
|
||||
# Write YAML config
|
||||
with open(config_path, "w") as f:
|
||||
f.write("compute_environment: LOCAL_MACHINE\n")
|
||||
f.write("distributed_type: MULTI_GPU\n")
|
||||
f.write("mixed_precision: 'no'\n")
|
||||
f.write(f"num_processes: {num_processes}\n")
|
||||
f.write("use_cpu: false\n")
|
||||
f.write("gpu_ids: all\n")
|
||||
f.write("downcast_bf16: 'no'\n")
|
||||
f.write("machine_rank: 0\n")
|
||||
f.write("main_training_function: main\n")
|
||||
f.write("num_machines: 1\n")
|
||||
f.write("rdzv_backend: static\n")
|
||||
f.write("same_network: true\n")
|
||||
if distributed_type == "FSDP":
|
||||
_write_fsdp_config(f, num_processes)
|
||||
else:
|
||||
_write_multi_gpu_config(f, num_processes)
|
||||
|
||||
cmd = [
|
||||
"accelerate",
|
||||
@@ -211,3 +243,66 @@ class TestMultiGPUTraining:
|
||||
# Verify optimizer state exists
|
||||
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
||||
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
|
||||
|
||||
def test_fsdp_optimizer_save_and_resume(self):
|
||||
"""
|
||||
Test that FSDP saves the (gathered) optimizer state and can resume from it.
|
||||
|
||||
Trains a few steps under FSDP, verifies the gathered optimizer state is written next to the
|
||||
rest of the training state, then resumes from the checkpoint for more steps and checks it
|
||||
completes without shape/key errors in the FSDP optimizer load path.
|
||||
"""
|
||||
# Pre-download dataset to avoid race conditions
|
||||
download_dataset("lerobot/pusht", episodes=[0])
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
output_dir = Path(temp_dir) / "outputs"
|
||||
|
||||
config_args = [
|
||||
"--dataset.repo_id=lerobot/pusht",
|
||||
"--dataset.episodes=[0]",
|
||||
"--policy.type=act",
|
||||
"--policy.device=cuda",
|
||||
"--policy.push_to_hub=false",
|
||||
f"--output_dir={output_dir}",
|
||||
"--batch_size=4",
|
||||
"--steps=10",
|
||||
"--eval_freq=-1",
|
||||
"--log_freq=5",
|
||||
"--save_freq=10",
|
||||
"--seed=42",
|
||||
"--num_workers=0",
|
||||
]
|
||||
|
||||
result = run_accelerate_training(
|
||||
config_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
|
||||
)
|
||||
assert result.returncode == 0, (
|
||||
f"FSDP training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
|
||||
)
|
||||
|
||||
# The gathered optimizer state must be written under FSDP (proves the save collective ran),
|
||||
# in the same safetensors format as single-GPU training.
|
||||
training_state_dir = output_dir / "checkpoints" / "last" / "training_state"
|
||||
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
||||
optimizer_param_groups = training_state_dir / "optimizer_param_groups.json"
|
||||
assert optimizer_state.exists(), f"FSDP optimizer state not saved in {training_state_dir}"
|
||||
assert optimizer_param_groups.exists(), (
|
||||
f"FSDP optimizer param groups not saved in {training_state_dir}"
|
||||
)
|
||||
|
||||
# Resume from the checkpoint for more steps. A successful run proves load_fsdp_optimizer
|
||||
# accepts the saved state and reshards it without shape/key errors.
|
||||
resume_config = output_dir / "checkpoints" / "last" / "pretrained_model" / "train_config.json"
|
||||
resume_args = [
|
||||
f"--config_path={resume_config}",
|
||||
"--resume=true",
|
||||
"--steps=20",
|
||||
]
|
||||
resume_result = run_accelerate_training(
|
||||
resume_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
|
||||
)
|
||||
assert resume_result.returncode == 0, (
|
||||
f"FSDP resume failed:\nSTDOUT:\n{resume_result.stdout}\n\nSTDERR:\n{resume_result.stderr}"
|
||||
)
|
||||
assert "End of training" in resume_result.stdout or "End of training" in resume_result.stderr
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
@@ -24,6 +24,7 @@ from lerobot.common.train_utils import (
|
||||
load_training_num_processes,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
push_checkpoint_to_hub,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
@@ -136,3 +137,50 @@ def test_save_load_training_state(tmp_path, optimizer, scheduler):
|
||||
assert loaded_step == 10
|
||||
assert loaded_optimizer is optimizer
|
||||
assert loaded_scheduler is scheduler
|
||||
|
||||
|
||||
def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
|
||||
# FSDP loads optimizer separately (after accelerator.prepare)
|
||||
# load_training_state(load_optimizer=False) must restore step + scheduler but leave the
|
||||
# optimizer untouched and never touch the on-disk optimizer state.
|
||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
||||
with patch("lerobot.common.train_utils.load_optimizer_state") as mock_load_optimizer_state:
|
||||
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
|
||||
tmp_path, optimizer, scheduler, load_optimizer=False
|
||||
)
|
||||
mock_load_optimizer_state.assert_not_called()
|
||||
assert loaded_step == 10
|
||||
assert loaded_optimizer is optimizer
|
||||
assert loaded_scheduler is scheduler
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_creates_repo_and_uploads(tmp_path, monkeypatch):
|
||||
import huggingface_hub
|
||||
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run", private=True)
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is True
|
||||
assert api.create_repo.call_args.kwargs["repo_type"] == "model"
|
||||
api.upload_folder.assert_called_once()
|
||||
kwargs = api.upload_folder.call_args.kwargs
|
||||
assert kwargs["repo_id"] == "user/run"
|
||||
assert kwargs["repo_type"] == "model"
|
||||
assert kwargs["path_in_repo"] == "checkpoints/010000"
|
||||
assert kwargs["folder_path"] == str(ckpt)
|
||||
assert kwargs["commit_message"] == "checkpoint 010000"
|
||||
|
||||
|
||||
def test_push_checkpoint_to_hub_defaults_to_hub_default_visibility(tmp_path, monkeypatch):
|
||||
import huggingface_hub
|
||||
|
||||
ckpt = tmp_path / "010000"
|
||||
(ckpt / "pretrained_model").mkdir(parents=True)
|
||||
api = MagicMock()
|
||||
monkeypatch.setattr(huggingface_hub, "HfApi", lambda *a, **k: api)
|
||||
push_checkpoint_to_hub(ckpt, "user/run")
|
||||
api.create_repo.assert_called_once()
|
||||
assert api.create_repo.call_args.kwargs["private"] is None
|
||||
|
||||
Reference in New Issue
Block a user