mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 9449e68725 | |||
| 2b83956eb5 | |||
| 7309790d56 | |||
| 2201401c99 | |||
| 64773e7b22 |
@@ -167,9 +167,9 @@ jobs:
|
|||||||
|
|
||||||
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
|
# ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
|
||||||
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
|
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
|
||||||
# immediately runs eval inside the training loop (eval_freq=1, 1 episode).
|
# immediately runs eval inside the training loop (env_eval_freq=1, 1 episode).
|
||||||
# Tests the full train→eval-within-training pipeline end-to-end.
|
# Tests the full train→eval-within-training pipeline end-to-end.
|
||||||
- name: Run Libero train+eval smoke (1 step, eval_freq=1)
|
- name: Run Libero train+eval smoke (1 step, env_eval_freq=1)
|
||||||
if: env.HF_USER_TOKEN != ''
|
if: env.HF_USER_TOKEN != ''
|
||||||
run: |
|
run: |
|
||||||
docker run --name libero-train-smoke --gpus all \
|
docker run --name libero-train-smoke --gpus all \
|
||||||
@@ -196,7 +196,7 @@ jobs:
|
|||||||
--output_dir=/tmp/train-smoke \
|
--output_dir=/tmp/train-smoke \
|
||||||
--steps=1 \
|
--steps=1 \
|
||||||
--batch_size=1 \
|
--batch_size=1 \
|
||||||
--eval_freq=1 \
|
--env_eval_freq=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.use_async_envs=false \
|
--eval.use_async_envs=false \
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ test-act-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=4 \
|
--steps=4 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
@@ -96,7 +96,7 @@ test-diffusion-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=2 \
|
--steps=2 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
@@ -126,7 +126,7 @@ test-tdmpc-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=2 \
|
--steps=2 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
@@ -161,7 +161,7 @@ test-smolvla-ete-train:
|
|||||||
--dataset.episodes="[0]" \
|
--dataset.episodes="[0]" \
|
||||||
--batch_size=2 \
|
--batch_size=2 \
|
||||||
--steps=4 \
|
--steps=4 \
|
||||||
--eval_freq=2 \
|
--env_eval_freq=2 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--save_freq=2 \
|
--save_freq=2 \
|
||||||
|
|||||||
@@ -719,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
|
|||||||
"num_workers": 4,
|
"num_workers": 4,
|
||||||
"steps": 5000,
|
"steps": 5000,
|
||||||
"log_freq": 10,
|
"log_freq": 10,
|
||||||
"eval_freq": 1000,
|
"env_eval_freq": 1000,
|
||||||
"save_freq": 1000,
|
"save_freq": 1000,
|
||||||
"save_checkpoint": true,
|
"save_checkpoint": true,
|
||||||
"seed": 2,
|
"seed": 2,
|
||||||
|
|||||||
@@ -143,7 +143,7 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Reproducing published results
|
## Reproducing published results
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Relationship to LIBERO
|
## Relationship to LIBERO
|
||||||
|
|||||||
@@ -120,11 +120,11 @@ lerobot-train \
|
|||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--eval_freq=1000
|
--env_eval_freq=1000
|
||||||
```
|
```
|
||||||
|
|
||||||
## Practical tips
|
## Practical tips
|
||||||
|
|
||||||
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
|
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
|
||||||
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
|
||||||
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget.
|
- Adjust `batch_size`, `steps`, and `env_eval_freq` to match your compute budget.
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ accelerate launch \
|
|||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--num_workers=4 \
|
--num_workers=4 \
|
||||||
--log_freq=20 \
|
--log_freq=20 \
|
||||||
--eval_freq=-1 \
|
--env_eval_freq=-1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
--save_freq=2000
|
--save_freq=2000
|
||||||
```
|
```
|
||||||
@@ -142,7 +142,7 @@ accelerate launch \
|
|||||||
--batch_size=32 \
|
--batch_size=32 \
|
||||||
--num_workers=4 \
|
--num_workers=4 \
|
||||||
--log_freq=20 \
|
--log_freq=20 \
|
||||||
--eval_freq=-1 \
|
--env_eval_freq=-1 \
|
||||||
--save_checkpoint=true \
|
--save_checkpoint=true \
|
||||||
--save_freq=2000
|
--save_freq=2000
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -113,61 +113,6 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
|
|||||||
--policy=act
|
--policy=act
|
||||||
```
|
```
|
||||||
|
|
||||||
## Training Large Models with FSDP
|
|
||||||
|
|
||||||
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
|
|
||||||
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
|
|
||||||
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
|
|
||||||
|
|
||||||
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
|
|
||||||
|
|
||||||
```bash
|
|
||||||
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
|
|
||||||
--dataset.repo_id=${HF_USER}/my_dataset \
|
|
||||||
--policy.type=<your_policy> \
|
|
||||||
--output_dir=outputs/train/my_policy_fsdp
|
|
||||||
```
|
|
||||||
|
|
||||||
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
compute_environment: LOCAL_MACHINE
|
|
||||||
distributed_type: FSDP
|
|
||||||
mixed_precision: bf16
|
|
||||||
num_machines: 1
|
|
||||||
num_processes: 4
|
|
||||||
fsdp_config:
|
|
||||||
fsdp_version: 1
|
|
||||||
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
|
|
||||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
|
||||||
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
|
|
||||||
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
|
|
||||||
fsdp_state_dict_type: FULL_STATE_DICT
|
|
||||||
```
|
|
||||||
|
|
||||||
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
|
|
||||||
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
|
|
||||||
optimizer before `accelerator.prepare()`.
|
|
||||||
|
|
||||||
### FSDP checkpoints
|
|
||||||
|
|
||||||
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
|
|
||||||
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two things to look out for:
|
|
||||||
|
|
||||||
- **Checkpoints store fp32 weights.** Under mixed precision (`bf16`/`fp16`) FSDP keeps an fp32 master
|
|
||||||
copy, and the checkpoint saves it (~2× the bf16 size on disk) so training can resume consistently
|
|
||||||
with the fp32 optimizer state; `from_pretrained` casts back to the policy dtype on load. FSDP-specific
|
|
||||||
caveat: an fp32 checkpoint is materialized in full precision on the target device _before_ casting,
|
|
||||||
so loading it for inference on a tight GPU can OOM even when the bf16 model would fit — load on CPU
|
|
||||||
first, or cast `model.safetensors` to the deployment dtype offline.
|
|
||||||
- The sharded optimizer state is gathered into a full (world-size-independent) state dict and saved
|
|
||||||
alongside the model in the same `optimizer_state.safetensors` / `optimizer_param_groups.json`
|
|
||||||
format as single-GPU training, so **resume-from-checkpoint is supported** with `--resume=true`.
|
|
||||||
Resume reshards both the model and the optimizer state to the _current_ FSDP topology, so you can
|
|
||||||
resume an FSDP checkpoint on a different number of GPUs. Note that the data sampler is only
|
|
||||||
sample-exact when the world size and batch size match the original run (a warning is logged
|
|
||||||
otherwise); the optimizer/model state itself is unaffected.
|
|
||||||
|
|
||||||
## Notes
|
## Notes
|
||||||
|
|
||||||
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ lerobot-train \
|
|||||||
--steps=30000 \
|
--steps=30000 \
|
||||||
--save_freq=1000 \
|
--save_freq=1000 \
|
||||||
--log_freq=100 \
|
--log_freq=100 \
|
||||||
--eval_freq=1000 \
|
--env_eval_freq=1000 \
|
||||||
--policy.type=multi_task_dit \
|
--policy.type=multi_task_dit \
|
||||||
--policy.device=cuda \
|
--policy.device=cuda \
|
||||||
--policy.horizon=32 \
|
--policy.horizon=32 \
|
||||||
|
|||||||
@@ -166,7 +166,7 @@ lerobot-train \
|
|||||||
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
--output_dir=./outputs/smolvla_robocasa_CloseFridge \
|
||||||
--steps=100000 \
|
--steps=100000 \
|
||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval_freq=5000 \
|
--env_eval_freq=5000 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=5 \
|
--eval.n_episodes=5 \
|
||||||
--save_freq=10000
|
--save_freq=10000
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ lerobot-train \
|
|||||||
--output_dir=./outputs/smolvla_vlabench_primitive \
|
--output_dir=./outputs/smolvla_vlabench_primitive \
|
||||||
--steps=100000 \
|
--steps=100000 \
|
||||||
--batch_size=4 \
|
--batch_size=4 \
|
||||||
--eval_freq=5000 \
|
--env_eval_freq=5000 \
|
||||||
--eval.batch_size=1 \
|
--eval.batch_size=1 \
|
||||||
--eval.n_episodes=1 \
|
--eval.n_episodes=1 \
|
||||||
--save_freq=10000
|
--save_freq=10000
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from torch.optim.lr_scheduler import LRScheduler
|
|||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.optim import (
|
from lerobot.optim import (
|
||||||
load_optimizer_state,
|
load_optimizer_state,
|
||||||
load_optimizer_state_dict,
|
|
||||||
load_scheduler_state,
|
load_scheduler_state,
|
||||||
save_optimizer_state,
|
save_optimizer_state,
|
||||||
save_scheduler_state,
|
save_scheduler_state,
|
||||||
@@ -99,8 +98,6 @@ def save_checkpoint(
|
|||||||
postprocessor: PolicyProcessorPipeline | None = None,
|
postprocessor: PolicyProcessorPipeline | None = None,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
model_state_dict: dict | None = None,
|
|
||||||
optim_state_dict: dict | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""This function creates the following directory structure:
|
"""This function creates the following directory structure:
|
||||||
|
|
||||||
@@ -130,18 +127,9 @@ def save_checkpoint(
|
|||||||
resume. Defaults to None (not recorded).
|
resume. Defaults to None (not recorded).
|
||||||
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
batch_size (int | None, optional): Per-process batch size to record for sample-exact
|
||||||
resume. Defaults to None (not recorded).
|
resume. Defaults to None (not recorded).
|
||||||
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
|
|
||||||
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
|
|
||||||
cross-rank collective and passes it here so rank 0 can write it directly. It holds
|
|
||||||
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
|
|
||||||
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
|
|
||||||
optim_state_dict: Pre-gathered full (unsharded) optimizer state dict. Required under FSDP
|
|
||||||
(gathered alongside `model_state_dict` via `gather_fsdp_state_dicts`); saved in the same
|
|
||||||
safetensors format as the single-GPU path. When None, `optimizer.state_dict()` is used.
|
|
||||||
Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||||
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict)
|
policy.save_pretrained(pretrained_dir)
|
||||||
cfg.save_pretrained(pretrained_dir)
|
cfg.save_pretrained(pretrained_dir)
|
||||||
if cfg.peft is not None:
|
if cfg.peft is not None:
|
||||||
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
|
||||||
@@ -152,13 +140,7 @@ def save_checkpoint(
|
|||||||
if postprocessor is not None:
|
if postprocessor is not None:
|
||||||
postprocessor.save_pretrained(pretrained_dir)
|
postprocessor.save_pretrained(pretrained_dir)
|
||||||
save_training_state(
|
save_training_state(
|
||||||
checkpoint_dir,
|
checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
|
||||||
step,
|
|
||||||
optimizer,
|
|
||||||
scheduler,
|
|
||||||
num_processes=num_processes,
|
|
||||||
batch_size=batch_size,
|
|
||||||
optim_state_dict=optim_state_dict,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -169,7 +151,6 @@ def save_training_state(
|
|||||||
scheduler: LRScheduler | None = None,
|
scheduler: LRScheduler | None = None,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
optim_state_dict: dict | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Saves the training step, optimizer state, scheduler state, and rng state.
|
Saves the training step, optimizer state, scheduler state, and rng state.
|
||||||
@@ -183,21 +164,19 @@ def save_training_state(
|
|||||||
Defaults to None.
|
Defaults to None.
|
||||||
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
num_processes (int | None, optional): Distributed world size to record. Defaults to None.
|
||||||
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
|
||||||
optim_state_dict: Pre-gathered full optimizer state dict (for FSDP). Saved instead of
|
|
||||||
`optimizer.state_dict()` when provided. Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
save_dir = checkpoint_dir / TRAINING_STATE_DIR
|
||||||
save_dir.mkdir(parents=True, exist_ok=True)
|
save_dir.mkdir(parents=True, exist_ok=True)
|
||||||
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
|
||||||
save_rng_state(save_dir)
|
save_rng_state(save_dir)
|
||||||
if optimizer is not None:
|
if optimizer is not None:
|
||||||
save_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
|
save_optimizer_state(optimizer, save_dir)
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
save_scheduler_state(scheduler, save_dir)
|
save_scheduler_state(scheduler, save_dir)
|
||||||
|
|
||||||
|
|
||||||
def load_training_state(
|
def load_training_state(
|
||||||
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None, load_optimizer: bool = True
|
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
|
||||||
) -> tuple[int, Optimizer, LRScheduler | None]:
|
) -> tuple[int, Optimizer, LRScheduler | None]:
|
||||||
"""
|
"""
|
||||||
Loads the training step, optimizer state, scheduler state, and rng state.
|
Loads the training step, optimizer state, scheduler state, and rng state.
|
||||||
@@ -207,10 +186,6 @@ def load_training_state(
|
|||||||
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
|
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
|
||||||
optimizer (Optimizer): The optimizer to load the state_dict to.
|
optimizer (Optimizer): The optimizer to load the state_dict to.
|
||||||
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
|
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
|
||||||
load_optimizer (bool, optional): Whether to load the optimizer state from disk. Defaults to
|
|
||||||
True. Set to False under FSDP, where the sharded optimizer state must be loaded after
|
|
||||||
`accelerator.prepare()` via `load_fsdp_optimizer_state` (the optimizer is returned
|
|
||||||
untouched here).
|
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
|
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
|
||||||
@@ -225,61 +200,8 @@ def load_training_state(
|
|||||||
|
|
||||||
load_rng_state(training_state_dir)
|
load_rng_state(training_state_dir)
|
||||||
step = load_training_step(training_state_dir)
|
step = load_training_step(training_state_dir)
|
||||||
if load_optimizer:
|
optimizer = load_optimizer_state(optimizer, training_state_dir)
|
||||||
optimizer = load_optimizer_state(optimizer, training_state_dir)
|
|
||||||
if scheduler is not None:
|
if scheduler is not None:
|
||||||
scheduler = load_scheduler_state(scheduler, training_state_dir)
|
scheduler = load_scheduler_state(scheduler, training_state_dir)
|
||||||
|
|
||||||
return step, optimizer, scheduler
|
return step, optimizer, scheduler
|
||||||
|
|
||||||
|
|
||||||
def gather_fsdp_state_dicts(model, optimizer) -> tuple[dict, dict]:
|
|
||||||
"""Gather the full (unsharded) model and optimizer state dicts under FSDP.
|
|
||||||
|
|
||||||
`model.state_dict()` and `FSDP.optim_state_dict(...)` are cross-rank collectives, so this must be
|
|
||||||
called on *every* rank with the prepared (FSDP-wrapped) `model` and `optimizer`. With
|
|
||||||
`rank0_only=True` and `offload_to_cpu=True`, every rank runs the all-gather but only rank 0
|
|
||||||
materializes the full dicts (the others get empty dicts) and they are kept on CPU to bound GPU
|
|
||||||
memory. The returned optimizer state dict is keyed by parameter FQNs and is world-size
|
|
||||||
independent; `load_fsdp_optimizer_state` reshards it on resume.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(model_state_dict, optim_state_dict): full dicts on rank 0, empty dicts on other ranks.
|
|
||||||
"""
|
|
||||||
from torch.distributed.fsdp import (
|
|
||||||
FullOptimStateDictConfig,
|
|
||||||
FullStateDictConfig,
|
|
||||||
FullyShardedDataParallel as FSDP, # noqa F401
|
|
||||||
StateDictType,
|
|
||||||
)
|
|
||||||
|
|
||||||
state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
|
||||||
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
|
||||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
|
||||||
model_state_dict = model.state_dict()
|
|
||||||
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
|
|
||||||
return model_state_dict, optim_state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
|
|
||||||
"""Load the FSDP optimizer state (saved as safetensors) and reshard it into the optimizer.
|
|
||||||
|
|
||||||
This is a cross-rank collective and must be called on every rank *after* `accelerator.prepare()`
|
|
||||||
with the prepared (FSDP-wrapped) `model` and `optimizer`. The saved state is the full,
|
|
||||||
world-size-independent optimizer state (keyed by parameter FQNs); `FSDP.optim_state_dict_to_load`
|
|
||||||
reshards it to the current FSDP topology, so resume on a different number of GPUs works.
|
|
||||||
"""
|
|
||||||
from torch.distributed.fsdp import (
|
|
||||||
FullOptimStateDictConfig,
|
|
||||||
FullStateDictConfig,
|
|
||||||
FullyShardedDataParallel as FSDP, # noqa F401
|
|
||||||
StateDictType,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Every rank reads the same full state from the (shared) checkpoint dir, so rank0_only=False.
|
|
||||||
full_osd = load_optimizer_state_dict(checkpoint_dir / TRAINING_STATE_DIR)
|
|
||||||
state_cfg = FullStateDictConfig(rank0_only=False)
|
|
||||||
optim_cfg = FullOptimStateDictConfig(rank0_only=False)
|
|
||||||
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
|
|
||||||
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
|
|
||||||
optimizer.load_state_dict(sharded_osd)
|
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ class DatasetConfig:
|
|||||||
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||||
return_uint8: bool = False
|
return_uint8: bool = False
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
|
||||||
|
eval_split: float = 0.0
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.episodes is not None:
|
if self.episodes is not None:
|
||||||
|
|||||||
@@ -100,8 +100,13 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
prefetch_factor: int = 4
|
prefetch_factor: int = 4
|
||||||
persistent_workers: bool = True
|
persistent_workers: bool = True
|
||||||
steps: int = 100_000
|
steps: int = 100_000
|
||||||
eval_freq: int = 20_000
|
# Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
|
||||||
|
env_eval_freq: int = 20_000
|
||||||
log_freq: int = 200
|
log_freq: int = 200
|
||||||
|
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0.
|
||||||
|
eval_steps: int = 0
|
||||||
|
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
|
||||||
|
max_eval_samples: int = 0
|
||||||
tolerance_s: float = 1e-4
|
tolerance_s: float = 1e-4
|
||||||
save_checkpoint: bool = True
|
save_checkpoint: bool = True
|
||||||
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ from .dataset_tools import (
|
|||||||
remove_feature,
|
remove_feature,
|
||||||
split_dataset,
|
split_dataset,
|
||||||
)
|
)
|
||||||
from .factory import make_dataset, resolve_delta_timestamps
|
from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
|
||||||
from .image_writer import safe_stop_image_writer
|
from .image_writer import safe_stop_image_writer
|
||||||
from .io_utils import load_episodes, write_stats
|
from .io_utils import load_episodes, write_stats
|
||||||
from .language import (
|
from .language import (
|
||||||
@@ -89,6 +89,7 @@ __all__ = [
|
|||||||
"get_feature_stats",
|
"get_feature_stats",
|
||||||
"load_episodes",
|
"load_episodes",
|
||||||
"make_dataset",
|
"make_dataset",
|
||||||
|
"make_train_eval_datasets",
|
||||||
"merge_datasets",
|
"merge_datasets",
|
||||||
"modify_features",
|
"modify_features",
|
||||||
"modify_tasks",
|
"modify_tasks",
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -130,3 +131,81 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def make_train_eval_datasets(
|
||||||
|
cfg: TrainPipelineConfig,
|
||||||
|
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
|
||||||
|
"""Create train and optional eval datasets by splitting episodes based on eval_split.
|
||||||
|
|
||||||
|
The last ceil(n_episodes * eval_split) episodes per task are held out for evaluation.
|
||||||
|
If eval_split == 0.0, returns (full_dataset, None).
|
||||||
|
"""
|
||||||
|
full_dataset = make_dataset(cfg)
|
||||||
|
|
||||||
|
if cfg.dataset.eval_split == 0.0:
|
||||||
|
return full_dataset, None
|
||||||
|
|
||||||
|
base_episodes = (
|
||||||
|
full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes))
|
||||||
|
)
|
||||||
|
|
||||||
|
episode_tasks = full_dataset.meta.episodes["tasks"]
|
||||||
|
task_to_episodes: dict[str, list[int]] = {}
|
||||||
|
for ep_idx in base_episodes:
|
||||||
|
task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else ""
|
||||||
|
task_to_episodes.setdefault(task_key, []).append(ep_idx)
|
||||||
|
|
||||||
|
train_episodes, eval_episodes = [], []
|
||||||
|
for eps in task_to_episodes.values():
|
||||||
|
n_eval = math.ceil(len(eps) * cfg.dataset.eval_split)
|
||||||
|
train_episodes.extend(eps[: len(eps) - n_eval])
|
||||||
|
eval_episodes.extend(eps[len(eps) - n_eval :])
|
||||||
|
|
||||||
|
if not train_episodes:
|
||||||
|
raise ValueError(
|
||||||
|
f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info(
|
||||||
|
f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
|
||||||
|
f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
|
||||||
|
)
|
||||||
|
|
||||||
|
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)
|
||||||
|
|
||||||
|
train_image_transforms = (
|
||||||
|
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
|
||||||
|
)
|
||||||
|
|
||||||
|
train_dataset = LeRobotDataset(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
episodes=train_episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=train_image_transforms,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
return_uint8=True,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
eval_dataset = LeRobotDataset(
|
||||||
|
cfg.dataset.repo_id,
|
||||||
|
root=cfg.dataset.root,
|
||||||
|
episodes=eval_episodes,
|
||||||
|
delta_timestamps=delta_timestamps,
|
||||||
|
image_transforms=None,
|
||||||
|
revision=cfg.dataset.revision,
|
||||||
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
return_uint8=True,
|
||||||
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.dataset.use_imagenet_stats:
|
||||||
|
for ds in (train_dataset, eval_dataset):
|
||||||
|
for key in ds.meta.camera_keys:
|
||||||
|
for stats_type, stats in IMAGENET_STATS.items():
|
||||||
|
ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
|
||||||
|
|
||||||
|
return train_dataset, eval_dataset
|
||||||
|
|||||||
@@ -370,6 +370,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.reader.load_and_activate()
|
self.reader.load_and_activate()
|
||||||
return self.reader.hf_dataset
|
return self.reader.hf_dataset
|
||||||
|
|
||||||
|
@property
|
||||||
|
def absolute_to_relative_idx(self) -> dict[int, int] | None:
|
||||||
|
"""Mapping from absolute frame indices to HF dataset row positions.
|
||||||
|
|
||||||
|
Non-None only for episode-filtered datasets where absolute indices
|
||||||
|
(from metadata) differ from row positions in the loaded HF dataset.
|
||||||
|
"""
|
||||||
|
reader = self._ensure_reader()
|
||||||
|
if reader.hf_dataset is None:
|
||||||
|
reader.load_and_activate()
|
||||||
|
return reader._absolute_to_relative_idx
|
||||||
|
|
||||||
# ── Writer-delegated methods ──────────────────────────────────────
|
# ── Writer-delegated methods ──────────────────────────────────────
|
||||||
|
|
||||||
def add_frame(self, frame: dict) -> None:
|
def add_frame(self, frame: dict) -> None:
|
||||||
|
|||||||
@@ -53,6 +53,7 @@ class EpisodeAwareSampler:
|
|||||||
drop_n_last_frames: int = 0,
|
drop_n_last_frames: int = 0,
|
||||||
shuffle: bool = False,
|
shuffle: bool = False,
|
||||||
seed: int = 0,
|
seed: int = 0,
|
||||||
|
absolute_to_relative_idx: dict[int, int] | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -107,6 +108,7 @@ class EpisodeAwareSampler:
|
|||||||
self.seed = seed
|
self.seed = seed
|
||||||
self._epoch = 0
|
self._epoch = 0
|
||||||
self._start_index = 0
|
self._start_index = 0
|
||||||
|
self._absolute_to_relative = absolute_to_relative_idx
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def indices(self) -> list[int]:
|
def indices(self) -> list[int]:
|
||||||
@@ -132,7 +134,10 @@ class EpisodeAwareSampler:
|
|||||||
def _frame_index(self, position: int) -> int:
|
def _frame_index(self, position: int) -> int:
|
||||||
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
|
||||||
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
|
||||||
return int(self._starts[episode]) + position_in_episode
|
absolute_idx = int(self._starts[episode]) + position_in_episode
|
||||||
|
if self._absolute_to_relative is not None:
|
||||||
|
return self._absolute_to_relative[absolute_idx]
|
||||||
|
return absolute_idx
|
||||||
|
|
||||||
def __iter__(self) -> Iterator[int]:
|
def __iter__(self) -> Iterator[int]:
|
||||||
# Advance epoch state eagerly, not on first consumption of the generator.
|
# Advance epoch state eagerly, not on first consumption of the generator.
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from .optimizers import (
|
|||||||
SGDConfig as SGDConfig,
|
SGDConfig as SGDConfig,
|
||||||
XVLAAdamWConfig as XVLAAdamWConfig,
|
XVLAAdamWConfig as XVLAAdamWConfig,
|
||||||
load_optimizer_state,
|
load_optimizer_state,
|
||||||
load_optimizer_state_dict,
|
|
||||||
save_optimizer_state,
|
save_optimizer_state,
|
||||||
)
|
)
|
||||||
from .schedulers import (
|
from .schedulers import (
|
||||||
@@ -51,7 +50,6 @@ __all__ = [
|
|||||||
"VQBeTSchedulerConfig",
|
"VQBeTSchedulerConfig",
|
||||||
# State management
|
# State management
|
||||||
"load_optimizer_state",
|
"load_optimizer_state",
|
||||||
"load_optimizer_state_dict",
|
|
||||||
"load_scheduler_state",
|
"load_scheduler_state",
|
||||||
"save_optimizer_state",
|
"save_optimizer_state",
|
||||||
"save_scheduler_state",
|
"save_scheduler_state",
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ from lerobot.utils.constants import (
|
|||||||
OPTIMIZER_PARAM_GROUPS,
|
OPTIMIZER_PARAM_GROUPS,
|
||||||
OPTIMIZER_STATE,
|
OPTIMIZER_STATE,
|
||||||
)
|
)
|
||||||
from lerobot.utils.io_utils import deserialize_json_into_object, load_json, write_json
|
from lerobot.utils.io_utils import deserialize_json_into_object, write_json
|
||||||
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
||||||
|
|
||||||
# Type alias for parameters accepted by optimizer build() methods.
|
# Type alias for parameters accepted by optimizer build() methods.
|
||||||
@@ -281,37 +281,28 @@ class MultiAdamConfig(OptimizerConfig):
|
|||||||
|
|
||||||
|
|
||||||
def save_optimizer_state(
|
def save_optimizer_state(
|
||||||
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer],
|
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
|
||||||
save_dir: Path,
|
|
||||||
optim_state_dict: dict | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Save optimizer state to disk.
|
"""Save optimizer state to disk.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
optimizer: Either a single optimizer or a dictionary of optimizers.
|
optimizer: Either a single optimizer or a dictionary of optimizers.
|
||||||
save_dir: Directory to save the optimizer state.
|
save_dir: Directory to save the optimizer state.
|
||||||
optim_state_dict: Pre-gathered optimizer state dict (for FSDP, where the sharded state must
|
|
||||||
be gathered across ranks first). If provided, it is saved directly instead of calling
|
|
||||||
``optimizer.state_dict()``. Only supported for a single optimizer. Defaults to None.
|
|
||||||
"""
|
"""
|
||||||
if isinstance(optimizer, dict):
|
if isinstance(optimizer, dict):
|
||||||
# Handle dictionary of optimizers
|
# Handle dictionary of optimizers
|
||||||
if optim_state_dict is not None:
|
|
||||||
raise ValueError("optim_state_dict is not supported for a dict of optimizers")
|
|
||||||
for name, opt in optimizer.items():
|
for name, opt in optimizer.items():
|
||||||
optimizer_dir = save_dir / name
|
optimizer_dir = save_dir / name
|
||||||
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
optimizer_dir.mkdir(exist_ok=True, parents=True)
|
||||||
_save_single_optimizer_state(opt, optimizer_dir)
|
_save_single_optimizer_state(opt, optimizer_dir)
|
||||||
else:
|
else:
|
||||||
# Handle single optimizer
|
# Handle single optimizer
|
||||||
_save_single_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict)
|
_save_single_optimizer_state(optimizer, save_dir)
|
||||||
|
|
||||||
|
|
||||||
def _save_single_optimizer_state(
|
def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
|
||||||
optimizer: torch.optim.Optimizer, save_dir: Path, optim_state_dict: dict | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Save a single optimizer's state to disk."""
|
"""Save a single optimizer's state to disk."""
|
||||||
state = dict(optim_state_dict) if optim_state_dict is not None else optimizer.state_dict()
|
state = optimizer.state_dict()
|
||||||
param_groups = state.pop("param_groups")
|
param_groups = state.pop("param_groups")
|
||||||
flat_state = flatten_dict(state)
|
flat_state = flatten_dict(state)
|
||||||
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
save_file(flat_state, save_dir / OPTIMIZER_STATE)
|
||||||
@@ -365,19 +356,3 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
|
|||||||
|
|
||||||
optimizer.load_state_dict(loaded_state_dict)
|
optimizer.load_state_dict(loaded_state_dict)
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
def load_optimizer_state_dict(save_dir: Path) -> dict:
|
|
||||||
"""Read a saved optimizer state dict (safetensors + json) back into a plain dict.
|
|
||||||
|
|
||||||
Unlike `load_optimizer_state`, this does not load into an optimizer and preserves the original
|
|
||||||
``state`` keys verbatim (e.g. FSDP parameter FQNs, which are not integer-castable). It is used by
|
|
||||||
the FSDP resume path, where the full state must be resharded via `FSDP.optim_state_dict_to_load`
|
|
||||||
before being loaded into the (sharded) optimizer.
|
|
||||||
"""
|
|
||||||
flat_state = load_file(save_dir / OPTIMIZER_STATE)
|
|
||||||
state = unflatten_dict(flat_state)
|
|
||||||
return {
|
|
||||||
"state": state.get("state", {}),
|
|
||||||
"param_groups": load_json(save_dir / OPTIMIZER_PARAM_GROUPS),
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
|
|||||||
|
|
||||||
import packaging
|
import packaging
|
||||||
import safetensors
|
import safetensors
|
||||||
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict
|
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
|
||||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||||
from huggingface_hub.errors import HfHubHTTPError
|
from huggingface_hub.errors import HfHubHTTPError
|
||||||
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
|
||||||
@@ -129,43 +129,10 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
if not getattr(cls, "name", None):
|
if not getattr(cls, "name", None):
|
||||||
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
raise TypeError(f"Class {cls.__name__} must define 'name'")
|
||||||
|
|
||||||
def save_pretrained(
|
def _save_pretrained(self, save_directory: Path) -> None:
|
||||||
self,
|
|
||||||
save_directory: str | Path,
|
|
||||||
*,
|
|
||||||
state_dict: dict[str, Tensor] | None = None,
|
|
||||||
repo_id: str | None = None,
|
|
||||||
push_to_hub: bool = False,
|
|
||||||
card_kwargs: dict | None = None,
|
|
||||||
**push_to_hub_kwargs,
|
|
||||||
) -> str | None:
|
|
||||||
"""Save the policy to a directory (and optionally push to the Hub).
|
|
||||||
|
|
||||||
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
|
|
||||||
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
|
|
||||||
return sharded tensors, so the caller gathers the full state dict via a cross-rank
|
|
||||||
collective and passes it here for `_save_pretrained` to write directly.
|
|
||||||
"""
|
|
||||||
save_directory = Path(save_directory)
|
|
||||||
save_directory.mkdir(parents=True, exist_ok=True)
|
|
||||||
self._save_pretrained(save_directory, state_dict=state_dict)
|
|
||||||
if push_to_hub:
|
|
||||||
if repo_id is None:
|
|
||||||
repo_id = save_directory.name
|
|
||||||
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
|
|
||||||
self.config._save_pretrained(save_directory)
|
self.config._save_pretrained(save_directory)
|
||||||
model_to_save = self.module if hasattr(self, "module") else self
|
model_to_save = self.module if hasattr(self, "module") else self
|
||||||
if state_dict is None:
|
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
||||||
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
|
|
||||||
return
|
|
||||||
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
|
|
||||||
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
|
|
||||||
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
|
|
||||||
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
|
|
||||||
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
@@ -303,7 +270,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
self,
|
self,
|
||||||
cfg: TrainPipelineConfig,
|
cfg: TrainPipelineConfig,
|
||||||
peft_model=None,
|
peft_model=None,
|
||||||
state_dict: dict[str, Tensor] | None = None,
|
|
||||||
):
|
):
|
||||||
api = HfApi()
|
api = HfApi()
|
||||||
repo_id = api.create_repo(
|
repo_id = api.create_repo(
|
||||||
@@ -321,8 +287,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
|||||||
peft_model.save_pretrained(saved_path)
|
peft_model.save_pretrained(saved_path)
|
||||||
self.config.save_pretrained(saved_path)
|
self.config.save_pretrained(saved_path)
|
||||||
else:
|
else:
|
||||||
# Calls _save_pretrained and stores model tensors
|
self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
|
||||||
self.save_pretrained(saved_path, state_dict=state_dict)
|
|
||||||
|
|
||||||
card = self.generate_model_card(
|
card = self.generate_model_card(
|
||||||
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
|
||||||
|
|||||||
@@ -34,10 +34,8 @@ from torch.optim import Optimizer
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.common.train_utils import (
|
from lerobot.common.train_utils import (
|
||||||
gather_fsdp_state_dicts,
|
|
||||||
get_step_checkpoint_dir,
|
get_step_checkpoint_dir,
|
||||||
get_step_identifier,
|
get_step_identifier,
|
||||||
load_fsdp_optimizer_state,
|
|
||||||
load_training_batch_size,
|
load_training_batch_size,
|
||||||
load_training_num_processes,
|
load_training_num_processes,
|
||||||
load_training_state,
|
load_training_state,
|
||||||
@@ -47,7 +45,8 @@ from lerobot.common.train_utils import (
|
|||||||
from lerobot.common.wandb_utils import WandBLogger
|
from lerobot.common.wandb_utils import WandBLogger
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
|
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state
|
||||||
|
from lerobot.datasets.factory import make_train_eval_datasets
|
||||||
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
|
||||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||||
@@ -191,7 +190,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
|
|
||||||
require_package("accelerate", extra="training")
|
require_package("accelerate", extra="training")
|
||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
|
|
||||||
|
|
||||||
cfg.validate()
|
cfg.validate()
|
||||||
|
|
||||||
@@ -200,6 +198,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||||
# We set find_unused_parameters=True to handle models with conditional computation
|
# We set find_unused_parameters=True to handle models with conditional computation
|
||||||
if accelerator is None:
|
if accelerator is None:
|
||||||
|
from accelerate.utils import DistributedDataParallelKwargs
|
||||||
|
|
||||||
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
|
||||||
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
|
||||||
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
|
||||||
@@ -245,19 +245,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("Creating dataset")
|
logging.info("Creating dataset")
|
||||||
dataset = make_dataset(cfg)
|
dataset, eval_dataset = make_train_eval_datasets(cfg)
|
||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
# Other ranks read from the shared copy populated by the main process.
|
# Other ranks read from the shared copy populated by the main process.
|
||||||
if not is_main_process:
|
if not is_main_process:
|
||||||
dataset = make_dataset(cfg)
|
dataset, eval_dataset = make_train_eval_datasets(cfg)
|
||||||
|
|
||||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||||
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
# On real-world data, no need to create an environment as evaluations are done outside train.py,
|
||||||
# using the eval.py instead, with gym_dora environment and dora-rs.
|
# using the eval.py instead, with gym_dora environment and dora-rs.
|
||||||
eval_env = None
|
eval_env = None
|
||||||
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process:
|
if cfg.env_eval_freq > 0 and cfg.env is not None and is_main_process:
|
||||||
logging.info("Creating env")
|
logging.info("Creating env")
|
||||||
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
|
||||||
|
|
||||||
@@ -371,12 +371,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
step = 0 # number of policy updates (forward + backward + optim)
|
step = 0 # number of policy updates (forward + backward + optim)
|
||||||
|
|
||||||
if cfg.resume:
|
if cfg.resume:
|
||||||
# Under FSDP the optimizer state is sharded and must be loaded after `accelerator.prepare()`
|
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||||
# (see load_fsdp_optimizer_state below), so skip the optimizer here and load it then.
|
|
||||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
|
||||||
step, optimizer, lr_scheduler = load_training_state(
|
|
||||||
cfg.checkpoint_path, optimizer, lr_scheduler, load_optimizer=not is_fsdp
|
|
||||||
)
|
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
num_total_params = sum(p.numel() for p in policy.parameters())
|
num_total_params = sum(p.numel() for p in policy.parameters())
|
||||||
@@ -412,6 +407,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
seed=cfg.seed if cfg.seed is not None else 0,
|
seed=cfg.seed if cfg.seed is not None else 0,
|
||||||
|
absolute_to_relative_idx=dataset.absolute_to_relative_idx,
|
||||||
)
|
)
|
||||||
if cfg.resume and step > 0:
|
if cfg.resume and step > 0:
|
||||||
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so
|
||||||
@@ -461,17 +457,38 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Build eval dataloader if a held-out split exists
|
||||||
|
eval_dataloader = None
|
||||||
|
if eval_dataset is not None:
|
||||||
|
eval_ds = eval_dataset
|
||||||
|
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
|
||||||
|
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
|
||||||
|
unique_tasks = sorted(set(task_arr.tolist()))
|
||||||
|
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
|
||||||
|
selected: list[int] = []
|
||||||
|
for t in unique_tasks:
|
||||||
|
frames = (task_arr == t).nonzero()[0][:per_task]
|
||||||
|
selected.extend(frames.tolist())
|
||||||
|
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
|
||||||
|
|
||||||
|
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||||
|
eval_dataloader = torch.utils.data.DataLoader(
|
||||||
|
eval_ds,
|
||||||
|
batch_size=cfg.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=cfg.num_workers,
|
||||||
|
pin_memory=device.type == "cuda",
|
||||||
|
drop_last=False,
|
||||||
|
collate_fn=eval_collate_fn,
|
||||||
|
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||||
|
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||||
|
)
|
||||||
|
|
||||||
# Prepare everything with accelerator
|
# Prepare everything with accelerator
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
|
||||||
policy, optimizer, dataloader, lr_scheduler
|
policy, optimizer, dataloader, lr_scheduler
|
||||||
)
|
)
|
||||||
|
|
||||||
# FSDP optimizer state is sharded across ranks, so it can only be loaded once the optimizer and
|
|
||||||
# model are FSDP-wrapped (i.e. after `prepare`). Collective: every rank must participate.
|
|
||||||
if cfg.resume and accelerator.distributed_type == DistributedType.FSDP:
|
|
||||||
load_fsdp_optimizer_state(policy, optimizer, cfg.checkpoint_path)
|
|
||||||
|
|
||||||
dl_iter = cycle(dataloader)
|
dl_iter = cycle(dataloader)
|
||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
@@ -546,7 +563,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
train_tracker.step()
|
train_tracker.step()
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
|
||||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||||
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
|
is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0
|
||||||
|
is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0
|
||||||
|
|
||||||
if is_log_step:
|
if is_log_step:
|
||||||
# Collective reduce must run on every rank, before the main-process gate below.
|
# Collective reduce must run on every rank, before the main-process gate below.
|
||||||
@@ -569,15 +587,28 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
wandb_logger.log_dict(wandb_log_dict, step)
|
wandb_logger.log_dict(wandb_log_dict, step)
|
||||||
train_tracker.reset_averages()
|
train_tracker.reset_averages()
|
||||||
|
|
||||||
|
if is_eval_step:
|
||||||
|
policy.eval()
|
||||||
|
eval_loss_sum = 0.0
|
||||||
|
n_eval_batches = 0
|
||||||
|
with torch.no_grad(), accelerator.autocast():
|
||||||
|
for eval_batch in eval_dataloader:
|
||||||
|
for cam_key in dataset.meta.camera_keys:
|
||||||
|
if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8:
|
||||||
|
eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||||
|
eval_batch = preprocessor(eval_batch)
|
||||||
|
loss, _ = policy.forward(eval_batch)
|
||||||
|
eval_loss_sum += loss.item()
|
||||||
|
n_eval_batches += 1
|
||||||
|
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
|
||||||
|
policy.train()
|
||||||
|
|
||||||
|
if is_main_process:
|
||||||
|
logging.info(f"step {step}: eval_loss={eval_loss:.4f}")
|
||||||
|
if wandb_logger:
|
||||||
|
wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval")
|
||||||
|
|
||||||
if cfg.save_checkpoint and is_saving_step:
|
if cfg.save_checkpoint and is_saving_step:
|
||||||
# Under FSDP, gathering the full model + optimizer state dicts is a cross-rank collective,
|
|
||||||
# so all ranks must participate; rank 0 then writes the materialized dicts. For DDP /
|
|
||||||
# single-GPU the state dicts are saved the normal way inside save_checkpoint.
|
|
||||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
|
||||||
if is_fsdp:
|
|
||||||
model_state_dict, optim_state_dict = gather_fsdp_state_dicts(policy, optimizer)
|
|
||||||
else:
|
|
||||||
model_state_dict, optim_state_dict = None, None
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info(f"Checkpoint policy after step {step}")
|
logging.info(f"Checkpoint policy after step {step}")
|
||||||
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
|
||||||
@@ -592,8 +623,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
num_processes=accelerator.num_processes,
|
num_processes=accelerator.num_processes,
|
||||||
batch_size=cfg.batch_size,
|
batch_size=cfg.batch_size,
|
||||||
model_state_dict=model_state_dict,
|
|
||||||
optim_state_dict=optim_state_dict,
|
|
||||||
)
|
)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
update_last_checkpoint(checkpoint_dir)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
@@ -601,7 +630,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
if cfg.env and is_eval_step:
|
if cfg.env and is_env_eval_step:
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
@@ -656,8 +685,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
if eval_env:
|
if eval_env:
|
||||||
close_envs(eval_env)
|
close_envs(eval_env)
|
||||||
|
|
||||||
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
|
|
||||||
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info("End of training")
|
logging.info("End of training")
|
||||||
|
|
||||||
@@ -667,7 +694,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
if not cfg.is_reward_model_training and cfg.policy.use_peft:
|
||||||
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
|
||||||
else:
|
else:
|
||||||
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict)
|
unwrapped_model.push_model_to_hub(cfg)
|
||||||
preprocessor.push_to_hub(active_cfg.repo_id)
|
preprocessor.push_to_hub(active_cfg.repo_id)
|
||||||
postprocessor.push_to_hub(active_cfg.repo_id)
|
postprocessor.push_to_hub(active_cfg.repo_id)
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from lerobot.optim.optimizers import (
|
|||||||
MultiAdamConfig,
|
MultiAdamConfig,
|
||||||
SGDConfig,
|
SGDConfig,
|
||||||
load_optimizer_state,
|
load_optimizer_state,
|
||||||
load_optimizer_state_dict,
|
|
||||||
save_optimizer_state,
|
save_optimizer_state,
|
||||||
)
|
)
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
@@ -66,44 +65,6 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
|
|||||||
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
|
||||||
|
|
||||||
|
|
||||||
def test_save_and_load_fsdp_optimizer_state_dict_roundtrip(tmp_path):
|
|
||||||
"""The FSDP full optimizer state dict is keyed by parameter FQNs (dotted strings), not the
|
|
||||||
integer indices of the single-GPU path. Verify it survives the safetensors save -> read
|
|
||||||
round-trip used by the FSDP save/resume path (save_optimizer_state(optim_state_dict=...) then
|
|
||||||
load_optimizer_state_dict), which the flatten/unflatten "/" separator must not corrupt."""
|
|
||||||
full_osd = {
|
|
||||||
"state": {
|
|
||||||
"model.layers.0.weight": {
|
|
||||||
"step": torch.tensor(3.0),
|
|
||||||
"exp_avg": torch.randn(4, 4),
|
|
||||||
"exp_avg_sq": torch.randn(4, 4),
|
|
||||||
},
|
|
||||||
"model.layers.0.bias": {
|
|
||||||
"step": torch.tensor(3.0),
|
|
||||||
"exp_avg": torch.randn(4),
|
|
||||||
"exp_avg_sq": torch.randn(4),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"param_groups": [
|
|
||||||
{"lr": 1e-4, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.0, "params": [0, 1]}
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
save_optimizer_state(
|
|
||||||
torch.optim.Adam([torch.nn.Parameter(torch.randn(1))]), tmp_path, optim_state_dict=full_osd
|
|
||||||
)
|
|
||||||
assert (tmp_path / OPTIMIZER_STATE).is_file()
|
|
||||||
assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
|
|
||||||
|
|
||||||
loaded = load_optimizer_state_dict(tmp_path)
|
|
||||||
# FQN keys must be preserved verbatim (not int-cast, not split on their dots).
|
|
||||||
assert set(loaded["state"].keys()) == set(full_osd["state"].keys())
|
|
||||||
for fqn, sub in full_osd["state"].items():
|
|
||||||
for k, v in sub.items():
|
|
||||||
torch.testing.assert_close(loaded["state"][fqn][k], v)
|
|
||||||
assert loaded["param_groups"] == full_osd["param_groups"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def base_params_dict():
|
def base_params_dict():
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ import torch
|
|||||||
|
|
||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
|
||||||
from packaging import version
|
from packaging import version
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
|
|
||||||
@@ -301,29 +300,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
|
|||||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
||||||
|
|
||||||
|
|
||||||
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
|
|
||||||
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
|
|
||||||
policy_cls = get_policy_class("act")
|
|
||||||
policy_cfg = make_policy_config("act")
|
|
||||||
features = dataset_to_policy_features(dummy_dataset_metadata.features)
|
|
||||||
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
|
||||||
policy_cfg.input_features = {
|
|
||||||
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
|
|
||||||
}
|
|
||||||
policy = policy_cls(policy_cfg)
|
|
||||||
policy.to(policy_cfg.device)
|
|
||||||
|
|
||||||
save_dir = tmp_path / "fsdp_state_dict"
|
|
||||||
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
|
|
||||||
|
|
||||||
# A single, unsharded safetensors file (no sharded set + index).
|
|
||||||
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
|
|
||||||
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
|
|
||||||
|
|
||||||
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
|
|
||||||
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("multikey", [True, False])
|
@pytest.mark.parametrize("multikey", [True, False])
|
||||||
def test_multikey_construction(multikey: bool):
|
def test_multikey_construction(multikey: bool):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -58,46 +58,7 @@ def download_dataset(repo_id, episodes):
|
|||||||
print(f"Dataset {repo_id} downloaded successfully")
|
print(f"Dataset {repo_id} downloaded successfully")
|
||||||
|
|
||||||
|
|
||||||
def _write_multi_gpu_config(f, num_processes):
|
def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
|
||||||
f.write("compute_environment: LOCAL_MACHINE\n")
|
|
||||||
f.write("distributed_type: MULTI_GPU\n")
|
|
||||||
f.write("mixed_precision: 'no'\n")
|
|
||||||
f.write(f"num_processes: {num_processes}\n")
|
|
||||||
f.write("use_cpu: false\n")
|
|
||||||
f.write("gpu_ids: all\n")
|
|
||||||
f.write("downcast_bf16: 'no'\n")
|
|
||||||
f.write("machine_rank: 0\n")
|
|
||||||
f.write("main_training_function: main\n")
|
|
||||||
f.write("num_machines: 1\n")
|
|
||||||
f.write("rdzv_backend: static\n")
|
|
||||||
f.write("same_network: true\n")
|
|
||||||
|
|
||||||
|
|
||||||
def _write_fsdp_config(f, num_processes):
|
|
||||||
# FSDP1 with FULL_SHARD (ZeRO-3-equivalent) and FULL_STATE_DICT, matching
|
|
||||||
# docs/source/multi_gpu_training.mdx. ACT's repeated transformer blocks are the wrap units;
|
|
||||||
# fsdp_use_orig_params is required because LeRobot builds the optimizer before prepare().
|
|
||||||
f.write("compute_environment: LOCAL_MACHINE\n")
|
|
||||||
f.write("distributed_type: FSDP\n")
|
|
||||||
f.write("mixed_precision: 'no'\n")
|
|
||||||
f.write(f"num_processes: {num_processes}\n")
|
|
||||||
f.write("use_cpu: false\n")
|
|
||||||
f.write("gpu_ids: all\n")
|
|
||||||
f.write("machine_rank: 0\n")
|
|
||||||
f.write("main_training_function: main\n")
|
|
||||||
f.write("num_machines: 1\n")
|
|
||||||
f.write("rdzv_backend: static\n")
|
|
||||||
f.write("same_network: true\n")
|
|
||||||
f.write("fsdp_config:\n")
|
|
||||||
f.write(" fsdp_version: 1\n")
|
|
||||||
f.write(" fsdp_sharding_strategy: FULL_SHARD\n")
|
|
||||||
f.write(" fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n")
|
|
||||||
f.write(" fsdp_transformer_layer_cls_to_wrap: ACTEncoderLayer,ACTDecoderLayer\n")
|
|
||||||
f.write(" fsdp_use_orig_params: true\n")
|
|
||||||
f.write(" fsdp_state_dict_type: FULL_STATE_DICT\n")
|
|
||||||
|
|
||||||
|
|
||||||
def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distributed_type="MULTI_GPU"):
|
|
||||||
"""
|
"""
|
||||||
Helper function to run training with accelerate launch.
|
Helper function to run training with accelerate launch.
|
||||||
|
|
||||||
@@ -105,7 +66,6 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distrib
|
|||||||
config_args: List of config arguments to pass to lerobot_train.py
|
config_args: List of config arguments to pass to lerobot_train.py
|
||||||
num_processes: Number of processes (GPUs) to use
|
num_processes: Number of processes (GPUs) to use
|
||||||
temp_dir: Temporary directory for outputs
|
temp_dir: Temporary directory for outputs
|
||||||
distributed_type: "MULTI_GPU" (DDP) or "FSDP" — selects the generated accelerate config.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
subprocess.CompletedProcess result
|
subprocess.CompletedProcess result
|
||||||
@@ -115,10 +75,18 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distrib
|
|||||||
|
|
||||||
# Write YAML config
|
# Write YAML config
|
||||||
with open(config_path, "w") as f:
|
with open(config_path, "w") as f:
|
||||||
if distributed_type == "FSDP":
|
f.write("compute_environment: LOCAL_MACHINE\n")
|
||||||
_write_fsdp_config(f, num_processes)
|
f.write("distributed_type: MULTI_GPU\n")
|
||||||
else:
|
f.write("mixed_precision: 'no'\n")
|
||||||
_write_multi_gpu_config(f, num_processes)
|
f.write(f"num_processes: {num_processes}\n")
|
||||||
|
f.write("use_cpu: false\n")
|
||||||
|
f.write("gpu_ids: all\n")
|
||||||
|
f.write("downcast_bf16: 'no'\n")
|
||||||
|
f.write("machine_rank: 0\n")
|
||||||
|
f.write("main_training_function: main\n")
|
||||||
|
f.write("num_machines: 1\n")
|
||||||
|
f.write("rdzv_backend: static\n")
|
||||||
|
f.write("same_network: true\n")
|
||||||
|
|
||||||
cmd = [
|
cmd = [
|
||||||
"accelerate",
|
"accelerate",
|
||||||
@@ -166,7 +134,7 @@ class TestMultiGPUTraining:
|
|||||||
f"--output_dir={output_dir}",
|
f"--output_dir={output_dir}",
|
||||||
"--batch_size=4",
|
"--batch_size=4",
|
||||||
"--steps=10",
|
"--steps=10",
|
||||||
"--eval_freq=-1",
|
"--env_eval_freq=-1",
|
||||||
"--log_freq=5",
|
"--log_freq=5",
|
||||||
"--save_freq=10",
|
"--save_freq=10",
|
||||||
"--seed=42",
|
"--seed=42",
|
||||||
@@ -209,7 +177,7 @@ class TestMultiGPUTraining:
|
|||||||
f"--output_dir={output_dir}",
|
f"--output_dir={output_dir}",
|
||||||
"--batch_size=4",
|
"--batch_size=4",
|
||||||
"--steps=20",
|
"--steps=20",
|
||||||
"--eval_freq=-1",
|
"--env_eval_freq=-1",
|
||||||
"--log_freq=5",
|
"--log_freq=5",
|
||||||
"--save_freq=10",
|
"--save_freq=10",
|
||||||
"--seed=42",
|
"--seed=42",
|
||||||
@@ -243,66 +211,3 @@ class TestMultiGPUTraining:
|
|||||||
# Verify optimizer state exists
|
# Verify optimizer state exists
|
||||||
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
||||||
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
|
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
|
||||||
|
|
||||||
def test_fsdp_optimizer_save_and_resume(self):
|
|
||||||
"""
|
|
||||||
Test that FSDP saves the (gathered) optimizer state and can resume from it.
|
|
||||||
|
|
||||||
Trains a few steps under FSDP, verifies the gathered optimizer state is written next to the
|
|
||||||
rest of the training state, then resumes from the checkpoint for more steps and checks it
|
|
||||||
completes without shape/key errors in the FSDP optimizer load path.
|
|
||||||
"""
|
|
||||||
# Pre-download dataset to avoid race conditions
|
|
||||||
download_dataset("lerobot/pusht", episodes=[0])
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
output_dir = Path(temp_dir) / "outputs"
|
|
||||||
|
|
||||||
config_args = [
|
|
||||||
"--dataset.repo_id=lerobot/pusht",
|
|
||||||
"--dataset.episodes=[0]",
|
|
||||||
"--policy.type=act",
|
|
||||||
"--policy.device=cuda",
|
|
||||||
"--policy.push_to_hub=false",
|
|
||||||
f"--output_dir={output_dir}",
|
|
||||||
"--batch_size=4",
|
|
||||||
"--steps=10",
|
|
||||||
"--eval_freq=-1",
|
|
||||||
"--log_freq=5",
|
|
||||||
"--save_freq=10",
|
|
||||||
"--seed=42",
|
|
||||||
"--num_workers=0",
|
|
||||||
]
|
|
||||||
|
|
||||||
result = run_accelerate_training(
|
|
||||||
config_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
|
|
||||||
)
|
|
||||||
assert result.returncode == 0, (
|
|
||||||
f"FSDP training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# The gathered optimizer state must be written under FSDP (proves the save collective ran),
|
|
||||||
# in the same safetensors format as single-GPU training.
|
|
||||||
training_state_dir = output_dir / "checkpoints" / "last" / "training_state"
|
|
||||||
optimizer_state = training_state_dir / "optimizer_state.safetensors"
|
|
||||||
optimizer_param_groups = training_state_dir / "optimizer_param_groups.json"
|
|
||||||
assert optimizer_state.exists(), f"FSDP optimizer state not saved in {training_state_dir}"
|
|
||||||
assert optimizer_param_groups.exists(), (
|
|
||||||
f"FSDP optimizer param groups not saved in {training_state_dir}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resume from the checkpoint for more steps. A successful run proves load_fsdp_optimizer
|
|
||||||
# accepts the saved state and reshards it without shape/key errors.
|
|
||||||
resume_config = output_dir / "checkpoints" / "last" / "pretrained_model" / "train_config.json"
|
|
||||||
resume_args = [
|
|
||||||
f"--config_path={resume_config}",
|
|
||||||
"--resume=true",
|
|
||||||
"--steps=20",
|
|
||||||
]
|
|
||||||
resume_result = run_accelerate_training(
|
|
||||||
resume_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
|
|
||||||
)
|
|
||||||
assert resume_result.returncode == 0, (
|
|
||||||
f"FSDP resume failed:\nSTDOUT:\n{resume_result.stdout}\n\nSTDERR:\n{resume_result.stderr}"
|
|
||||||
)
|
|
||||||
assert "End of training" in resume_result.stdout or "End of training" in resume_result.stderr
|
|
||||||
|
|||||||
@@ -136,18 +136,3 @@ def test_save_load_training_state(tmp_path, optimizer, scheduler):
|
|||||||
assert loaded_step == 10
|
assert loaded_step == 10
|
||||||
assert loaded_optimizer is optimizer
|
assert loaded_optimizer is optimizer
|
||||||
assert loaded_scheduler is scheduler
|
assert loaded_scheduler is scheduler
|
||||||
|
|
||||||
|
|
||||||
def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
|
|
||||||
# FSDP loads optimizer separately (after accelerator.prepare)
|
|
||||||
# load_training_state(load_optimizer=False) must restore step + scheduler but leave the
|
|
||||||
# optimizer untouched and never touch the on-disk optimizer state.
|
|
||||||
save_training_state(tmp_path, 10, optimizer, scheduler)
|
|
||||||
with patch("lerobot.common.train_utils.load_optimizer_state") as mock_load_optimizer_state:
|
|
||||||
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
|
|
||||||
tmp_path, optimizer, scheduler, load_optimizer=False
|
|
||||||
)
|
|
||||||
mock_load_optimizer_state.assert_not_called()
|
|
||||||
assert loaded_step == 10
|
|
||||||
assert loaded_optimizer is optimizer
|
|
||||||
assert loaded_scheduler is scheduler
|
|
||||||
|
|||||||
Reference in New Issue
Block a user