Compare commits

..

5 Commits

Author SHA1 Message Date
Khalil Meftah 9449e68725 fix(datasets): Move the remapping into EpisodeAwareSampler via absolute_to_relative_idx 2026-06-16 18:32:48 +02:00
Khalil Meftah 2b83956eb5 fix(train): vectorize eval subset selection for max_eval_samples 2026-06-16 16:22:55 +02:00
Khalil Meftah 7309790d56 fix(datasets): remap absolute indices in __getitem__ for filtered datasets 2026-06-16 15:15:11 +02:00
Khalil Meftah 2201401c99 feat(training): add inline offline validation with train/eval split
- Add eval_split config for balanced per-task holdout
- Add eval_steps for periodic inline eval loss computation
- Add max_eval_samples to cap eval cost
2026-06-14 21:29:54 +02:00
Khalil Meftah 64773e7b22 refactor(training): rename eval_freq to env_eval_freq
- Rename eval_freq to env_eval_freq to distinguish sim environment evaluation from offline loss evaluation.
2026-06-14 14:19:25 +02:00
26 changed files with 214 additions and 451 deletions
+3 -3
View File
@@ -167,9 +167,9 @@ jobs:
# ── LIBERO TRAIN+EVAL SMOKE ────────────────────────────────────────────── # ── LIBERO TRAIN+EVAL SMOKE ──────────────────────────────────────────────
# Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then # Train SmolVLA for 1 step (batch_size=1, dataset episode 0 only) then
# immediately runs eval inside the training loop (eval_freq=1, 1 episode). # immediately runs eval inside the training loop (env_eval_freq=1, 1 episode).
# Tests the full train→eval-within-training pipeline end-to-end. # Tests the full train→eval-within-training pipeline end-to-end.
- name: Run Libero train+eval smoke (1 step, eval_freq=1) - name: Run Libero train+eval smoke (1 step, env_eval_freq=1)
if: env.HF_USER_TOKEN != '' if: env.HF_USER_TOKEN != ''
run: | run: |
docker run --name libero-train-smoke --gpus all \ docker run --name libero-train-smoke --gpus all \
@@ -196,7 +196,7 @@ jobs:
--output_dir=/tmp/train-smoke \ --output_dir=/tmp/train-smoke \
--steps=1 \ --steps=1 \
--batch_size=1 \ --batch_size=1 \
--eval_freq=1 \ --env_eval_freq=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.use_async_envs=false \ --eval.use_async_envs=false \
+4 -4
View File
@@ -58,7 +58,7 @@ test-act-ete-train:
--dataset.episodes="[0]" \ --dataset.episodes="[0]" \
--batch_size=2 \ --batch_size=2 \
--steps=4 \ --steps=4 \
--eval_freq=2 \ --env_eval_freq=2 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--save_freq=2 \ --save_freq=2 \
@@ -96,7 +96,7 @@ test-diffusion-ete-train:
--dataset.episodes="[0]" \ --dataset.episodes="[0]" \
--batch_size=2 \ --batch_size=2 \
--steps=2 \ --steps=2 \
--eval_freq=2 \ --env_eval_freq=2 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--save_checkpoint=true \ --save_checkpoint=true \
@@ -126,7 +126,7 @@ test-tdmpc-ete-train:
--dataset.episodes="[0]" \ --dataset.episodes="[0]" \
--batch_size=2 \ --batch_size=2 \
--steps=2 \ --steps=2 \
--eval_freq=2 \ --env_eval_freq=2 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--save_checkpoint=true \ --save_checkpoint=true \
@@ -161,7 +161,7 @@ test-smolvla-ete-train:
--dataset.episodes="[0]" \ --dataset.episodes="[0]" \
--batch_size=2 \ --batch_size=2 \
--steps=4 \ --steps=4 \
--eval_freq=2 \ --env_eval_freq=2 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--save_freq=2 \ --save_freq=2 \
+1 -1
View File
@@ -719,7 +719,7 @@ Example configuration for training the [reward classifier](https://huggingface.c
"num_workers": 4, "num_workers": 4,
"steps": 5000, "steps": 5000,
"log_freq": 10, "log_freq": 10,
"eval_freq": 1000, "env_eval_freq": 1000,
"save_freq": 1000, "save_freq": 1000,
"save_checkpoint": true, "save_checkpoint": true,
"seed": 2, "seed": 2,
+1 -1
View File
@@ -143,7 +143,7 @@ lerobot-train \
--batch_size=4 \ --batch_size=4 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval_freq=1000 --env_eval_freq=1000
``` ```
## Reproducing published results ## Reproducing published results
+1 -1
View File
@@ -173,7 +173,7 @@ lerobot-train \
--batch_size=4 \ --batch_size=4 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval_freq=1000 --env_eval_freq=1000
``` ```
## Relationship to LIBERO ## Relationship to LIBERO
+2 -2
View File
@@ -120,11 +120,11 @@ lerobot-train \
--batch_size=4 \ --batch_size=4 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--eval_freq=1000 --env_eval_freq=1000
``` ```
## Practical tips ## Practical tips
- Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context. - Use the one-hot task conditioning for multi-task training (MT10/MT50 conventions) so policies have explicit task context.
- Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark. - Inspect the dataset task descriptions and the `info["is_success"]` keys when writing post-processing or logging so your success metrics line up with the benchmark.
- Adjust `batch_size`, `steps`, and `eval_freq` to match your compute budget. - Adjust `batch_size`, `steps`, and `env_eval_freq` to match your compute budget.
+2 -2
View File
@@ -103,7 +103,7 @@ accelerate launch \
--batch_size=32 \ --batch_size=32 \
--num_workers=4 \ --num_workers=4 \
--log_freq=20 \ --log_freq=20 \
--eval_freq=-1 \ --env_eval_freq=-1 \
--save_checkpoint=true \ --save_checkpoint=true \
--save_freq=2000 --save_freq=2000
``` ```
@@ -142,7 +142,7 @@ accelerate launch \
--batch_size=32 \ --batch_size=32 \
--num_workers=4 \ --num_workers=4 \
--log_freq=20 \ --log_freq=20 \
--eval_freq=-1 \ --env_eval_freq=-1 \
--save_checkpoint=true \ --save_checkpoint=true \
--save_freq=2000 --save_freq=2000
``` ```
-55
View File
@@ -113,61 +113,6 @@ accelerate launch --num_processes=2 $(which lerobot-train) \
--policy=act --policy=act
``` ```
## Training Large Models with FSDP
DDP replicates the full model on every GPU, so a model that doesn't fit on one GPU won't fit under
DDP either. For large models, use **FSDP** (Fully Sharded Data Parallel), which shards parameters,
gradients, and optimizer state across GPUs. See the [accelerate FSDP guide](https://huggingface.co/docs/accelerate/usage_guides/fsdp) for background.
An example on how to launch LeRobot training with FSDP across 4 GPUs (1 machine):
```bash
accelerate launch --config_file fsdp.yaml --num_processes=4 $(which lerobot-train) \
--dataset.repo_id=${HF_USER}/my_dataset \
--policy.type=<your_policy> \
--output_dir=outputs/train/my_policy_fsdp
```
A minimal `fsdp.yaml` (FSDP1; shards params/grads/optimizer — ZeRO-3-equivalent):
```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_machines: 1
num_processes: 4
fsdp_config:
fsdp_version: 1
fsdp_sharding_strategy: FULL_SHARD # params + grads + optimizer (ZeRO-3)
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: <YourTransformerBlock> # repeated block class to shard
fsdp_use_orig_params: true # required: optimizer is built pre-prepare
fsdp_state_dict_type: FULL_STATE_DICT
```
Set `fsdp_transformer_layer_cls_to_wrap` to your model's repeated transformer-block class so each
block is sharded as its own unit. `fsdp_use_orig_params: true` is required because LeRobot builds the
optimizer before `accelerator.prepare()`.
### FSDP checkpoints
LeRobot gathers the full state dict across all ranks and the main process writes it as a single
`model.safetensors`, loadable as usual with `Policy.from_pretrained(...)`. Two things to look out for:
- **Checkpoints store fp32 weights.** Under mixed precision (`bf16`/`fp16`) FSDP keeps an fp32 master
copy, and the checkpoint saves it (~2× the bf16 size on disk) so training can resume consistently
with the fp32 optimizer state; `from_pretrained` casts back to the policy dtype on load. FSDP-specific
caveat: an fp32 checkpoint is materialized in full precision on the target device _before_ casting,
so loading it for inference on a tight GPU can OOM even when the bf16 model would fit — load on CPU
first, or cast `model.safetensors` to the deployment dtype offline.
- The sharded optimizer state is gathered into a full (world-size-independent) state dict and saved
alongside the model in the same `optimizer_state.safetensors` / `optimizer_param_groups.json`
format as single-GPU training, so **resume-from-checkpoint is supported** with `--resume=true`.
Resume reshards both the model and the optimizer state to the _current_ FSDP topology, so you can
resume an FSDP checkpoint on a different number of GPUs. Note that the data sampler is only
sample-exact when the world size and batch size match the original run (a warning is logged
otherwise); the optimizer/model state itself is unaffected.
## Notes ## Notes
- The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration. - The `--policy.use_amp` flag in `lerobot-train` is only used when **not** running with accelerate. When using accelerate, mixed precision is controlled by accelerate's configuration.
+1 -1
View File
@@ -314,7 +314,7 @@ lerobot-train \
--steps=30000 \ --steps=30000 \
--save_freq=1000 \ --save_freq=1000 \
--log_freq=100 \ --log_freq=100 \
--eval_freq=1000 \ --env_eval_freq=1000 \
--policy.type=multi_task_dit \ --policy.type=multi_task_dit \
--policy.device=cuda \ --policy.device=cuda \
--policy.horizon=32 \ --policy.horizon=32 \
+1 -1
View File
@@ -166,7 +166,7 @@ lerobot-train \
--output_dir=./outputs/smolvla_robocasa_CloseFridge \ --output_dir=./outputs/smolvla_robocasa_CloseFridge \
--steps=100000 \ --steps=100000 \
--batch_size=4 \ --batch_size=4 \
--eval_freq=5000 \ --env_eval_freq=5000 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=5 \ --eval.n_episodes=5 \
--save_freq=10000 --save_freq=10000
+1 -1
View File
@@ -165,7 +165,7 @@ lerobot-train \
--output_dir=./outputs/smolvla_vlabench_primitive \ --output_dir=./outputs/smolvla_vlabench_primitive \
--steps=100000 \ --steps=100000 \
--batch_size=4 \ --batch_size=4 \
--eval_freq=5000 \ --env_eval_freq=5000 \
--eval.batch_size=1 \ --eval.batch_size=1 \
--eval.n_episodes=1 \ --eval.n_episodes=1 \
--save_freq=10000 --save_freq=10000
+5 -83
View File
@@ -21,7 +21,6 @@ from torch.optim.lr_scheduler import LRScheduler
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.optim import ( from lerobot.optim import (
load_optimizer_state, load_optimizer_state,
load_optimizer_state_dict,
load_scheduler_state, load_scheduler_state,
save_optimizer_state, save_optimizer_state,
save_scheduler_state, save_scheduler_state,
@@ -99,8 +98,6 @@ def save_checkpoint(
postprocessor: PolicyProcessorPipeline | None = None, postprocessor: PolicyProcessorPipeline | None = None,
num_processes: int | None = None, num_processes: int | None = None,
batch_size: int | None = None, batch_size: int | None = None,
model_state_dict: dict | None = None,
optim_state_dict: dict | None = None,
) -> None: ) -> None:
"""This function creates the following directory structure: """This function creates the following directory structure:
@@ -130,18 +127,9 @@ def save_checkpoint(
resume. Defaults to None (not recorded). resume. Defaults to None (not recorded).
batch_size (int | None, optional): Per-process batch size to record for sample-exact batch_size (int | None, optional): Per-process batch size to record for sample-exact
resume. Defaults to None (not recorded). resume. Defaults to None (not recorded).
model_state_dict: Pre-gathered full (unsharded) model state dict. Required under FSDP,
where `policy.state_dict()` would return sharded tensors; the caller gathers it via a
cross-rank collective and passes it here so rank 0 can write it directly. It holds
FSDP's fp32 master weights and is saved as-is (the loader casts to the policy dtype on
read). When None (DDP / single-GPU), the model is saved the normal way. Defaults to None.
optim_state_dict: Pre-gathered full (unsharded) optimizer state dict. Required under FSDP
(gathered alongside `model_state_dict` via `gather_fsdp_state_dicts`); saved in the same
safetensors format as the single-GPU path. When None, `optimizer.state_dict()` is used.
Defaults to None.
""" """
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir, state_dict=model_state_dict) policy.save_pretrained(pretrained_dir)
cfg.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir)
if cfg.peft is not None: if cfg.peft is not None:
# When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the # When using PEFT, policy.save_pretrained will only write the adapter weights + config, not the
@@ -152,13 +140,7 @@ def save_checkpoint(
if postprocessor is not None: if postprocessor is not None:
postprocessor.save_pretrained(pretrained_dir) postprocessor.save_pretrained(pretrained_dir)
save_training_state( save_training_state(
checkpoint_dir, checkpoint_dir, step, optimizer, scheduler, num_processes=num_processes, batch_size=batch_size
step,
optimizer,
scheduler,
num_processes=num_processes,
batch_size=batch_size,
optim_state_dict=optim_state_dict,
) )
@@ -169,7 +151,6 @@ def save_training_state(
scheduler: LRScheduler | None = None, scheduler: LRScheduler | None = None,
num_processes: int | None = None, num_processes: int | None = None,
batch_size: int | None = None, batch_size: int | None = None,
optim_state_dict: dict | None = None,
) -> None: ) -> None:
""" """
Saves the training step, optimizer state, scheduler state, and rng state. Saves the training step, optimizer state, scheduler state, and rng state.
@@ -183,21 +164,19 @@ def save_training_state(
Defaults to None. Defaults to None.
num_processes (int | None, optional): Distributed world size to record. Defaults to None. num_processes (int | None, optional): Distributed world size to record. Defaults to None.
batch_size (int | None, optional): Per-process batch size to record. Defaults to None. batch_size (int | None, optional): Per-process batch size to record. Defaults to None.
optim_state_dict: Pre-gathered full optimizer state dict (for FSDP). Saved instead of
`optimizer.state_dict()` when provided. Defaults to None.
""" """
save_dir = checkpoint_dir / TRAINING_STATE_DIR save_dir = checkpoint_dir / TRAINING_STATE_DIR
save_dir.mkdir(parents=True, exist_ok=True) save_dir.mkdir(parents=True, exist_ok=True)
save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size) save_training_step(train_step, save_dir, num_processes=num_processes, batch_size=batch_size)
save_rng_state(save_dir) save_rng_state(save_dir)
if optimizer is not None: if optimizer is not None:
save_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict) save_optimizer_state(optimizer, save_dir)
if scheduler is not None: if scheduler is not None:
save_scheduler_state(scheduler, save_dir) save_scheduler_state(scheduler, save_dir)
def load_training_state( def load_training_state(
checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None, load_optimizer: bool = True checkpoint_dir: Path, optimizer: Optimizer, scheduler: LRScheduler | None
) -> tuple[int, Optimizer, LRScheduler | None]: ) -> tuple[int, Optimizer, LRScheduler | None]:
""" """
Loads the training step, optimizer state, scheduler state, and rng state. Loads the training step, optimizer state, scheduler state, and rng state.
@@ -207,10 +186,6 @@ def load_training_state(
checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir. checkpoint_dir (Path): The checkpoint directory. Should contain a 'training_state' dir.
optimizer (Optimizer): The optimizer to load the state_dict to. optimizer (Optimizer): The optimizer to load the state_dict to.
scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None). scheduler (LRScheduler | None): The scheduler to load the state_dict to (can be None).
load_optimizer (bool, optional): Whether to load the optimizer state from disk. Defaults to
True. Set to False under FSDP, where the sharded optimizer state must be loaded after
`accelerator.prepare()` via `load_fsdp_optimizer_state` (the optimizer is returned
untouched here).
Raises: Raises:
NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir NotADirectoryError: If 'checkpoint_dir' doesn't contain a 'training_state' dir
@@ -225,61 +200,8 @@ def load_training_state(
load_rng_state(training_state_dir) load_rng_state(training_state_dir)
step = load_training_step(training_state_dir) step = load_training_step(training_state_dir)
if load_optimizer: optimizer = load_optimizer_state(optimizer, training_state_dir)
optimizer = load_optimizer_state(optimizer, training_state_dir)
if scheduler is not None: if scheduler is not None:
scheduler = load_scheduler_state(scheduler, training_state_dir) scheduler = load_scheduler_state(scheduler, training_state_dir)
return step, optimizer, scheduler return step, optimizer, scheduler
def gather_fsdp_state_dicts(model, optimizer) -> tuple[dict, dict]:
"""Gather the full (unsharded) model and optimizer state dicts under FSDP.
`model.state_dict()` and `FSDP.optim_state_dict(...)` are cross-rank collectives, so this must be
called on *every* rank with the prepared (FSDP-wrapped) `model` and `optimizer`. With
`rank0_only=True` and `offload_to_cpu=True`, every rank runs the all-gather but only rank 0
materializes the full dicts (the others get empty dicts) and they are kept on CPU to bound GPU
memory. The returned optimizer state dict is keyed by parameter FQNs and is world-size
independent; `load_fsdp_optimizer_state` reshards it on resume.
Returns:
(model_state_dict, optim_state_dict): full dicts on rank 0, empty dicts on other ranks.
"""
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP, # noqa F401
StateDictType,
)
state_cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
optim_cfg = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
model_state_dict = model.state_dict()
optim_state_dict = FSDP.optim_state_dict(model, optimizer)
return model_state_dict, optim_state_dict
def load_fsdp_optimizer_state(model, optimizer, checkpoint_dir: Path) -> None:
"""Load the FSDP optimizer state (saved as safetensors) and reshard it into the optimizer.
This is a cross-rank collective and must be called on every rank *after* `accelerator.prepare()`
with the prepared (FSDP-wrapped) `model` and `optimizer`. The saved state is the full,
world-size-independent optimizer state (keyed by parameter FQNs); `FSDP.optim_state_dict_to_load`
reshards it to the current FSDP topology, so resume on a different number of GPUs works.
"""
from torch.distributed.fsdp import (
FullOptimStateDictConfig,
FullStateDictConfig,
FullyShardedDataParallel as FSDP, # noqa F401
StateDictType,
)
# Every rank reads the same full state from the (shared) checkpoint dir, so rank0_only=False.
full_osd = load_optimizer_state_dict(checkpoint_dir / TRAINING_STATE_DIR)
state_cfg = FullStateDictConfig(rank0_only=False)
optim_cfg = FullOptimStateDictConfig(rank0_only=False)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, state_cfg, optim_cfg):
sharded_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=full_osd)
optimizer.load_state_dict(sharded_osd)
+2
View File
@@ -39,6 +39,8 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False return_uint8: bool = False
streaming: bool = False streaming: bool = False
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
eval_split: float = 0.0
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.episodes is not None: if self.episodes is not None:
+6 -1
View File
@@ -100,8 +100,13 @@ class TrainPipelineConfig(HubMixin):
prefetch_factor: int = 4 prefetch_factor: int = 4
persistent_workers: bool = True persistent_workers: bool = True
steps: int = 100_000 steps: int = 100_000
eval_freq: int = 20_000 # Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
env_eval_freq: int = 20_000
log_freq: int = 200 log_freq: int = 200
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0.
eval_steps: int = 0
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
max_eval_samples: int = 0
tolerance_s: float = 1e-4 tolerance_s: float = 1e-4
save_checkpoint: bool = True save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step. # Checkpoint is saved every `save_freq` training iterations and after the last training step.
+2 -1
View File
@@ -35,7 +35,7 @@ from .dataset_tools import (
remove_feature, remove_feature,
split_dataset, split_dataset,
) )
from .factory import make_dataset, resolve_delta_timestamps from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats from .io_utils import load_episodes, write_stats
from .language import ( from .language import (
@@ -89,6 +89,7 @@ __all__ = [
"get_feature_stats", "get_feature_stats",
"load_episodes", "load_episodes",
"make_dataset", "make_dataset",
"make_train_eval_datasets",
"merge_datasets", "merge_datasets",
"modify_features", "modify_features",
"modify_tasks", "modify_tasks",
+79
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import math
from pprint import pformat from pprint import pformat
import torch import torch
@@ -130,3 +131,81 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset return dataset
def make_train_eval_datasets(
cfg: TrainPipelineConfig,
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
"""Create train and optional eval datasets by splitting episodes based on eval_split.
The last ceil(n_episodes * eval_split) episodes per task are held out for evaluation.
If eval_split == 0.0, returns (full_dataset, None).
"""
full_dataset = make_dataset(cfg)
if cfg.dataset.eval_split == 0.0:
return full_dataset, None
base_episodes = (
full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes))
)
episode_tasks = full_dataset.meta.episodes["tasks"]
task_to_episodes: dict[str, list[int]] = {}
for ep_idx in base_episodes:
task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else ""
task_to_episodes.setdefault(task_key, []).append(ep_idx)
train_episodes, eval_episodes = [], []
for eps in task_to_episodes.values():
n_eval = math.ceil(len(eps) * cfg.dataset.eval_split)
train_episodes.extend(eps[: len(eps) - n_eval])
eval_episodes.extend(eps[len(eps) - n_eval :])
if not train_episodes:
raise ValueError(
f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
)
logging.info(
f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
)
delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)
train_image_transforms = (
ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
)
train_dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=train_episodes,
delta_timestamps=delta_timestamps,
image_transforms=train_image_transforms,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s,
)
eval_dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=eval_episodes,
delta_timestamps=delta_timestamps,
image_transforms=None,
revision=cfg.dataset.revision,
video_backend=cfg.dataset.video_backend,
return_uint8=True,
tolerance_s=cfg.tolerance_s,
)
if cfg.dataset.use_imagenet_stats:
for ds in (train_dataset, eval_dataset):
for key in ds.meta.camera_keys:
for stats_type, stats in IMAGENET_STATS.items():
ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return train_dataset, eval_dataset
+12
View File
@@ -370,6 +370,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
self.reader.load_and_activate() self.reader.load_and_activate()
return self.reader.hf_dataset return self.reader.hf_dataset
@property
def absolute_to_relative_idx(self) -> dict[int, int] | None:
"""Mapping from absolute frame indices to HF dataset row positions.
Non-None only for episode-filtered datasets where absolute indices
(from metadata) differ from row positions in the loaded HF dataset.
"""
reader = self._ensure_reader()
if reader.hf_dataset is None:
reader.load_and_activate()
return reader._absolute_to_relative_idx
# ── Writer-delegated methods ────────────────────────────────────── # ── Writer-delegated methods ──────────────────────────────────────
def add_frame(self, frame: dict) -> None: def add_frame(self, frame: dict) -> None:
+6 -1
View File
@@ -53,6 +53,7 @@ class EpisodeAwareSampler:
drop_n_last_frames: int = 0, drop_n_last_frames: int = 0,
shuffle: bool = False, shuffle: bool = False,
seed: int = 0, seed: int = 0,
absolute_to_relative_idx: dict[int, int] | None = None,
): ):
""" """
Args: Args:
@@ -107,6 +108,7 @@ class EpisodeAwareSampler:
self.seed = seed self.seed = seed
self._epoch = 0 self._epoch = 0
self._start_index = 0 self._start_index = 0
self._absolute_to_relative = absolute_to_relative_idx
@property @property
def indices(self) -> list[int]: def indices(self) -> list[int]:
@@ -132,7 +134,10 @@ class EpisodeAwareSampler:
def _frame_index(self, position: int) -> int: def _frame_index(self, position: int) -> int:
episode = int(np.searchsorted(self._cum_lengths, position, side="right")) episode = int(np.searchsorted(self._cum_lengths, position, side="right"))
position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0) position_in_episode = position - (int(self._cum_lengths[episode - 1]) if episode > 0 else 0)
return int(self._starts[episode]) + position_in_episode absolute_idx = int(self._starts[episode]) + position_in_episode
if self._absolute_to_relative is not None:
return self._absolute_to_relative[absolute_idx]
return absolute_idx
def __iter__(self) -> Iterator[int]: def __iter__(self) -> Iterator[int]:
# Advance epoch state eagerly, not on first consumption of the generator. # Advance epoch state eagerly, not on first consumption of the generator.
-2
View File
@@ -20,7 +20,6 @@ from .optimizers import (
SGDConfig as SGDConfig, SGDConfig as SGDConfig,
XVLAAdamWConfig as XVLAAdamWConfig, XVLAAdamWConfig as XVLAAdamWConfig,
load_optimizer_state, load_optimizer_state,
load_optimizer_state_dict,
save_optimizer_state, save_optimizer_state,
) )
from .schedulers import ( from .schedulers import (
@@ -51,7 +50,6 @@ __all__ = [
"VQBeTSchedulerConfig", "VQBeTSchedulerConfig",
# State management # State management
"load_optimizer_state", "load_optimizer_state",
"load_optimizer_state_dict",
"load_scheduler_state", "load_scheduler_state",
"save_optimizer_state", "save_optimizer_state",
"save_scheduler_state", "save_scheduler_state",
+5 -30
View File
@@ -27,7 +27,7 @@ from lerobot.utils.constants import (
OPTIMIZER_PARAM_GROUPS, OPTIMIZER_PARAM_GROUPS,
OPTIMIZER_STATE, OPTIMIZER_STATE,
) )
from lerobot.utils.io_utils import deserialize_json_into_object, load_json, write_json from lerobot.utils.io_utils import deserialize_json_into_object, write_json
from lerobot.utils.utils import flatten_dict, unflatten_dict from lerobot.utils.utils import flatten_dict, unflatten_dict
# Type alias for parameters accepted by optimizer build() methods. # Type alias for parameters accepted by optimizer build() methods.
@@ -281,37 +281,28 @@ class MultiAdamConfig(OptimizerConfig):
def save_optimizer_state( def save_optimizer_state(
optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], optimizer: torch.optim.Optimizer | dict[str, torch.optim.Optimizer], save_dir: Path
save_dir: Path,
optim_state_dict: dict | None = None,
) -> None: ) -> None:
"""Save optimizer state to disk. """Save optimizer state to disk.
Args: Args:
optimizer: Either a single optimizer or a dictionary of optimizers. optimizer: Either a single optimizer or a dictionary of optimizers.
save_dir: Directory to save the optimizer state. save_dir: Directory to save the optimizer state.
optim_state_dict: Pre-gathered optimizer state dict (for FSDP, where the sharded state must
be gathered across ranks first). If provided, it is saved directly instead of calling
``optimizer.state_dict()``. Only supported for a single optimizer. Defaults to None.
""" """
if isinstance(optimizer, dict): if isinstance(optimizer, dict):
# Handle dictionary of optimizers # Handle dictionary of optimizers
if optim_state_dict is not None:
raise ValueError("optim_state_dict is not supported for a dict of optimizers")
for name, opt in optimizer.items(): for name, opt in optimizer.items():
optimizer_dir = save_dir / name optimizer_dir = save_dir / name
optimizer_dir.mkdir(exist_ok=True, parents=True) optimizer_dir.mkdir(exist_ok=True, parents=True)
_save_single_optimizer_state(opt, optimizer_dir) _save_single_optimizer_state(opt, optimizer_dir)
else: else:
# Handle single optimizer # Handle single optimizer
_save_single_optimizer_state(optimizer, save_dir, optim_state_dict=optim_state_dict) _save_single_optimizer_state(optimizer, save_dir)
def _save_single_optimizer_state( def _save_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Path) -> None:
optimizer: torch.optim.Optimizer, save_dir: Path, optim_state_dict: dict | None = None
) -> None:
"""Save a single optimizer's state to disk.""" """Save a single optimizer's state to disk."""
state = dict(optim_state_dict) if optim_state_dict is not None else optimizer.state_dict() state = optimizer.state_dict()
param_groups = state.pop("param_groups") param_groups = state.pop("param_groups")
flat_state = flatten_dict(state) flat_state = flatten_dict(state)
save_file(flat_state, save_dir / OPTIMIZER_STATE) save_file(flat_state, save_dir / OPTIMIZER_STATE)
@@ -365,19 +356,3 @@ def _load_single_optimizer_state(optimizer: torch.optim.Optimizer, save_dir: Pat
optimizer.load_state_dict(loaded_state_dict) optimizer.load_state_dict(loaded_state_dict)
return optimizer return optimizer
def load_optimizer_state_dict(save_dir: Path) -> dict:
"""Read a saved optimizer state dict (safetensors + json) back into a plain dict.
Unlike `load_optimizer_state`, this does not load into an optimizer and preserves the original
``state`` keys verbatim (e.g. FSDP parameter FQNs, which are not integer-castable). It is used by
the FSDP resume path, where the full state must be resharded via `FSDP.optim_state_dict_to_load`
before being loaded into the (sharded) optimizer.
"""
flat_state = load_file(save_dir / OPTIMIZER_STATE)
state = unflatten_dict(flat_state)
return {
"state": state.get("state", {}),
"param_groups": load_json(save_dir / OPTIMIZER_PARAM_GROUPS),
}
+4 -39
View File
@@ -23,7 +23,7 @@ from typing import TypedDict, TypeVar, Unpack
import packaging import packaging
import safetensors import safetensors
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download, save_torch_state_dict from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError from huggingface_hub.errors import HfHubHTTPError
from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor from safetensors.torch import load_model as load_model_as_safetensor, save_model as save_model_as_safetensor
@@ -129,43 +129,10 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
if not getattr(cls, "name", None): if not getattr(cls, "name", None):
raise TypeError(f"Class {cls.__name__} must define 'name'") raise TypeError(f"Class {cls.__name__} must define 'name'")
def save_pretrained( def _save_pretrained(self, save_directory: Path) -> None:
self,
save_directory: str | Path,
*,
state_dict: dict[str, Tensor] | None = None,
repo_id: str | None = None,
push_to_hub: bool = False,
card_kwargs: dict | None = None,
**push_to_hub_kwargs,
) -> str | None:
"""Save the policy to a directory (and optionally push to the Hub).
Overrides `HubMixin.save_pretrained` to add a `state_dict` argument (mirroring
`transformers.PreTrainedModel.save_pretrained`). Under FSDP, `self.state_dict()` would
return sharded tensors, so the caller gathers the full state dict via a cross-rank
collective and passes it here for `_save_pretrained` to write directly.
"""
save_directory = Path(save_directory)
save_directory.mkdir(parents=True, exist_ok=True)
self._save_pretrained(save_directory, state_dict=state_dict)
if push_to_hub:
if repo_id is None:
repo_id = save_directory.name
return self.push_to_hub(repo_id=repo_id, card_kwargs=card_kwargs, **push_to_hub_kwargs)
return None
def _save_pretrained(self, save_directory: Path, state_dict: dict[str, Tensor] | None = None) -> None:
self.config._save_pretrained(save_directory) self.config._save_pretrained(save_directory)
model_to_save = self.module if hasattr(self, "module") else self model_to_save = self.module if hasattr(self, "module") else self
if state_dict is None: save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
save_model_as_safetensor(model_to_save, str(save_directory / SAFETENSORS_SINGLE_FILE))
return
# A pre-gathered (e.g. FSDP full) state dict was supplied: write it directly.
# `save_torch_state_dict` discards shared-tensor duplicates just like `save_model` does;
# pin `max_shard_size` above the total size so the output stays a single `model.safetensors`
total_bytes = sum(t.numel() * t.element_size() for t in state_dict.values())
save_torch_state_dict(state_dict, str(save_directory), max_shard_size=max(total_bytes, 1))
@classmethod @classmethod
def from_pretrained( def from_pretrained(
@@ -303,7 +270,6 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
self, self,
cfg: TrainPipelineConfig, cfg: TrainPipelineConfig,
peft_model=None, peft_model=None,
state_dict: dict[str, Tensor] | None = None,
): ):
api = HfApi() api = HfApi()
repo_id = api.create_repo( repo_id = api.create_repo(
@@ -321,8 +287,7 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
peft_model.save_pretrained(saved_path) peft_model.save_pretrained(saved_path)
self.config.save_pretrained(saved_path) self.config.save_pretrained(saved_path)
else: else:
# Calls _save_pretrained and stores model tensors self.save_pretrained(saved_path) # Calls _save_pretrained and stores model tensors
self.save_pretrained(saved_path, state_dict=state_dict)
card = self.generate_model_card( card = self.generate_model_card(
cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg cfg.dataset.repo_id, self.config.type, self.config.license, self.config.tags, cfg=cfg
+61 -34
View File
@@ -34,10 +34,8 @@ from torch.optim import Optimizer
from tqdm import tqdm from tqdm import tqdm
from lerobot.common.train_utils import ( from lerobot.common.train_utils import (
gather_fsdp_state_dicts,
get_step_checkpoint_dir, get_step_checkpoint_dir,
get_step_identifier, get_step_identifier,
load_fsdp_optimizer_state,
load_training_batch_size, load_training_batch_size,
load_training_num_processes, load_training_num_processes,
load_training_state, load_training_state,
@@ -47,7 +45,8 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state
from lerobot.datasets.factory import make_train_eval_datasets
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -191,7 +190,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
require_package("accelerate", extra="training") require_package("accelerate", extra="training")
from accelerate import Accelerator from accelerate import Accelerator
from accelerate.utils import DistributedDataParallelKwargs, DistributedType
cfg.validate() cfg.validate()
@@ -200,6 +198,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes # We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
# We set find_unused_parameters=True to handle models with conditional computation # We set find_unused_parameters=True to handle models with conditional computation
if accelerator is None: if accelerator is None:
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting. # Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training). # Force the device to be CPU when the active config's device is set to CPU (works for both policy and reward model training).
@@ -245,19 +245,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads. # LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
if is_main_process: if is_main_process:
logging.info("Creating dataset") logging.info("Creating dataset")
dataset = make_dataset(cfg) dataset, eval_dataset = make_train_eval_datasets(cfg)
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
# Other ranks read from the shared copy populated by the main process. # Other ranks read from the shared copy populated by the main process.
if not is_main_process: if not is_main_process:
dataset = make_dataset(cfg) dataset, eval_dataset = make_train_eval_datasets(cfg)
# Create environment used for evaluating checkpoints during training on simulation data. # Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py, # On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs. # using the eval.py instead, with gym_dora environment and dora-rs.
eval_env = None eval_env = None
if cfg.eval_freq > 0 and cfg.env is not None and is_main_process: if cfg.env_eval_freq > 0 and cfg.env is not None and is_main_process:
logging.info("Creating env") logging.info("Creating env")
eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs) eval_env = make_env(cfg.env, n_envs=cfg.eval.batch_size, use_async_envs=cfg.eval.use_async_envs)
@@ -371,12 +371,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
step = 0 # number of policy updates (forward + backward + optim) step = 0 # number of policy updates (forward + backward + optim)
if cfg.resume: if cfg.resume:
# Under FSDP the optimizer state is sharded and must be loaded after `accelerator.prepare()` step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
# (see load_fsdp_optimizer_state below), so skip the optimizer here and load it then.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
step, optimizer, lr_scheduler = load_training_state(
cfg.checkpoint_path, optimizer, lr_scheduler, load_optimizer=not is_fsdp
)
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
num_total_params = sum(p.numel() for p in policy.parameters()) num_total_params = sum(p.numel() for p in policy.parameters())
@@ -412,6 +407,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0), drop_n_last_frames=getattr(active_cfg, "drop_n_last_frames", 0),
shuffle=True, shuffle=True,
seed=cfg.seed if cfg.seed is not None else 0, seed=cfg.seed if cfg.seed is not None else 0,
absolute_to_relative_idx=dataset.absolute_to_relative_idx,
) )
if cfg.resume and step > 0: if cfg.resume and step > 0:
# The resume offset depends on the (num_processes, batch_size) that produced `step`, so # The resume offset depends on the (num_processes, batch_size) that produced `step`, so
@@ -461,17 +457,38 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0, persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
) )
# Build eval dataloader if a held-out split exists
eval_dataloader = None
if eval_dataset is not None:
eval_ds = eval_dataset
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
unique_tasks = sorted(set(task_arr.tolist()))
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
selected: list[int] = []
for t in unique_tasks:
frames = (task_arr == t).nonzero()[0][:per_task]
selected.extend(frames.tolist())
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
eval_dataloader = torch.utils.data.DataLoader(
eval_ds,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=eval_collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Prepare everything with accelerator # Prepare everything with accelerator
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
policy, optimizer, dataloader, lr_scheduler policy, optimizer, dataloader, lr_scheduler
) )
# FSDP optimizer state is sharded across ranks, so it can only be loaded once the optimizer and
# model are FSDP-wrapped (i.e. after `prepare`). Collective: every rank must participate.
if cfg.resume and accelerator.distributed_type == DistributedType.FSDP:
load_fsdp_optimizer_state(policy, optimizer, cfg.checkpoint_path)
dl_iter = cycle(dataloader) dl_iter = cycle(dataloader)
policy.train() policy.train()
@@ -546,7 +563,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
train_tracker.step() train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0
is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0
if is_log_step: if is_log_step:
# Collective reduce must run on every rank, before the main-process gate below. # Collective reduce must run on every rank, before the main-process gate below.
@@ -569,15 +587,28 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
wandb_logger.log_dict(wandb_log_dict, step) wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages() train_tracker.reset_averages()
if is_eval_step:
policy.eval()
eval_loss_sum = 0.0
n_eval_batches = 0
with torch.no_grad(), accelerator.autocast():
for eval_batch in eval_dataloader:
for cam_key in dataset.meta.camera_keys:
if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8:
eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0
eval_batch = preprocessor(eval_batch)
loss, _ = policy.forward(eval_batch)
eval_loss_sum += loss.item()
n_eval_batches += 1
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
policy.train()
if is_main_process:
logging.info(f"step {step}: eval_loss={eval_loss:.4f}")
if wandb_logger:
wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval")
if cfg.save_checkpoint and is_saving_step: if cfg.save_checkpoint and is_saving_step:
# Under FSDP, gathering the full model + optimizer state dicts is a cross-rank collective,
# so all ranks must participate; rank 0 then writes the materialized dicts. For DDP /
# single-GPU the state dicts are saved the normal way inside save_checkpoint.
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
if is_fsdp:
model_state_dict, optim_state_dict = gather_fsdp_state_dicts(policy, optimizer)
else:
model_state_dict, optim_state_dict = None, None
if is_main_process: if is_main_process:
logging.info(f"Checkpoint policy after step {step}") logging.info(f"Checkpoint policy after step {step}")
checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step)
@@ -592,8 +623,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
postprocessor=postprocessor, postprocessor=postprocessor,
num_processes=accelerator.num_processes, num_processes=accelerator.num_processes,
batch_size=cfg.batch_size, batch_size=cfg.batch_size,
model_state_dict=model_state_dict,
optim_state_dict=optim_state_dict,
) )
update_last_checkpoint(checkpoint_dir) update_last_checkpoint(checkpoint_dir)
if wandb_logger: if wandb_logger:
@@ -601,7 +630,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
accelerator.wait_for_everyone() accelerator.wait_for_everyone()
if cfg.env and is_eval_step: if cfg.env and is_env_eval_step:
if is_main_process: if is_main_process:
step_id = get_step_identifier(step, cfg.steps) step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}") logging.info(f"Eval policy at step {step}")
@@ -656,8 +685,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if eval_env: if eval_env:
close_envs(eval_env) close_envs(eval_env)
is_fsdp = accelerator.distributed_type == DistributedType.FSDP
model_state_dict = accelerator.get_state_dict(policy) if is_fsdp else None
if is_main_process: if is_main_process:
logging.info("End of training") logging.info("End of training")
@@ -667,7 +694,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if not cfg.is_reward_model_training and cfg.policy.use_peft: if not cfg.is_reward_model_training and cfg.policy.use_peft:
unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model) unwrapped_model.push_model_to_hub(cfg, peft_model=unwrapped_model)
else: else:
unwrapped_model.push_model_to_hub(cfg, state_dict=model_state_dict) unwrapped_model.push_model_to_hub(cfg)
preprocessor.push_to_hub(active_cfg.repo_id) preprocessor.push_to_hub(active_cfg.repo_id)
postprocessor.push_to_hub(active_cfg.repo_id) postprocessor.push_to_hub(active_cfg.repo_id)
-39
View File
@@ -20,7 +20,6 @@ from lerobot.optim.optimizers import (
MultiAdamConfig, MultiAdamConfig,
SGDConfig, SGDConfig,
load_optimizer_state, load_optimizer_state,
load_optimizer_state_dict,
save_optimizer_state, save_optimizer_state,
) )
from lerobot.utils.constants import ( from lerobot.utils.constants import (
@@ -66,44 +65,6 @@ def test_save_and_load_optimizer_state(model_params, optimizer, tmp_path):
torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict()) torch.testing.assert_close(optimizer.state_dict(), loaded_optimizer.state_dict())
def test_save_and_load_fsdp_optimizer_state_dict_roundtrip(tmp_path):
"""The FSDP full optimizer state dict is keyed by parameter FQNs (dotted strings), not the
integer indices of the single-GPU path. Verify it survives the safetensors save -> read
round-trip used by the FSDP save/resume path (save_optimizer_state(optim_state_dict=...) then
load_optimizer_state_dict), which the flatten/unflatten "/" separator must not corrupt."""
full_osd = {
"state": {
"model.layers.0.weight": {
"step": torch.tensor(3.0),
"exp_avg": torch.randn(4, 4),
"exp_avg_sq": torch.randn(4, 4),
},
"model.layers.0.bias": {
"step": torch.tensor(3.0),
"exp_avg": torch.randn(4),
"exp_avg_sq": torch.randn(4),
},
},
"param_groups": [
{"lr": 1e-4, "betas": [0.9, 0.999], "eps": 1e-8, "weight_decay": 0.0, "params": [0, 1]}
],
}
save_optimizer_state(
torch.optim.Adam([torch.nn.Parameter(torch.randn(1))]), tmp_path, optim_state_dict=full_osd
)
assert (tmp_path / OPTIMIZER_STATE).is_file()
assert (tmp_path / OPTIMIZER_PARAM_GROUPS).is_file()
loaded = load_optimizer_state_dict(tmp_path)
# FQN keys must be preserved verbatim (not int-cast, not split on their dots).
assert set(loaded["state"].keys()) == set(full_osd["state"].keys())
for fqn, sub in full_osd["state"].items():
for k, v in sub.items():
torch.testing.assert_close(loaded["state"][fqn][k], v)
assert loaded["param_groups"] == full_osd["param_groups"]
@pytest.fixture @pytest.fixture
def base_params_dict(): def base_params_dict():
return { return {
-24
View File
@@ -23,7 +23,6 @@ import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])") pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from packaging import version from packaging import version
from safetensors.torch import load_file from safetensors.torch import load_file
@@ -301,29 +300,6 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0) torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
def test_save_pretrained_with_state_dict(dummy_dataset_metadata, tmp_path):
"""Exercise the FSDP checkpoint path: save_pretrained with a pre-gathered state_dict."""
policy_cls = get_policy_class("act")
policy_cfg = make_policy_config("act")
features = dataset_to_policy_features(dummy_dataset_metadata.features)
policy_cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
policy_cfg.input_features = {
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
}
policy = policy_cls(policy_cfg)
policy.to(policy_cfg.device)
save_dir = tmp_path / "fsdp_state_dict"
policy.save_pretrained(save_dir, state_dict=policy.state_dict())
# A single, unsharded safetensors file (no sharded set + index).
assert (save_dir / SAFETENSORS_SINGLE_FILE).is_file()
assert not (save_dir / f"{SAFETENSORS_SINGLE_FILE}.index.json").exists()
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
@pytest.mark.parametrize("multikey", [True, False]) @pytest.mark.parametrize("multikey", [True, False])
def test_multikey_construction(multikey: bool): def test_multikey_construction(multikey: bool):
""" """
+15 -110
View File
@@ -58,46 +58,7 @@ def download_dataset(repo_id, episodes):
print(f"Dataset {repo_id} downloaded successfully") print(f"Dataset {repo_id} downloaded successfully")
def _write_multi_gpu_config(f, num_processes): def run_accelerate_training(config_args, num_processes=4, temp_dir=None):
f.write("compute_environment: LOCAL_MACHINE\n")
f.write("distributed_type: MULTI_GPU\n")
f.write("mixed_precision: 'no'\n")
f.write(f"num_processes: {num_processes}\n")
f.write("use_cpu: false\n")
f.write("gpu_ids: all\n")
f.write("downcast_bf16: 'no'\n")
f.write("machine_rank: 0\n")
f.write("main_training_function: main\n")
f.write("num_machines: 1\n")
f.write("rdzv_backend: static\n")
f.write("same_network: true\n")
def _write_fsdp_config(f, num_processes):
# FSDP1 with FULL_SHARD (ZeRO-3-equivalent) and FULL_STATE_DICT, matching
# docs/source/multi_gpu_training.mdx. ACT's repeated transformer blocks are the wrap units;
# fsdp_use_orig_params is required because LeRobot builds the optimizer before prepare().
f.write("compute_environment: LOCAL_MACHINE\n")
f.write("distributed_type: FSDP\n")
f.write("mixed_precision: 'no'\n")
f.write(f"num_processes: {num_processes}\n")
f.write("use_cpu: false\n")
f.write("gpu_ids: all\n")
f.write("machine_rank: 0\n")
f.write("main_training_function: main\n")
f.write("num_machines: 1\n")
f.write("rdzv_backend: static\n")
f.write("same_network: true\n")
f.write("fsdp_config:\n")
f.write(" fsdp_version: 1\n")
f.write(" fsdp_sharding_strategy: FULL_SHARD\n")
f.write(" fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP\n")
f.write(" fsdp_transformer_layer_cls_to_wrap: ACTEncoderLayer,ACTDecoderLayer\n")
f.write(" fsdp_use_orig_params: true\n")
f.write(" fsdp_state_dict_type: FULL_STATE_DICT\n")
def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distributed_type="MULTI_GPU"):
""" """
Helper function to run training with accelerate launch. Helper function to run training with accelerate launch.
@@ -105,7 +66,6 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distrib
config_args: List of config arguments to pass to lerobot_train.py config_args: List of config arguments to pass to lerobot_train.py
num_processes: Number of processes (GPUs) to use num_processes: Number of processes (GPUs) to use
temp_dir: Temporary directory for outputs temp_dir: Temporary directory for outputs
distributed_type: "MULTI_GPU" (DDP) or "FSDP" selects the generated accelerate config.
Returns: Returns:
subprocess.CompletedProcess result subprocess.CompletedProcess result
@@ -115,10 +75,18 @@ def run_accelerate_training(config_args, num_processes=4, temp_dir=None, distrib
# Write YAML config # Write YAML config
with open(config_path, "w") as f: with open(config_path, "w") as f:
if distributed_type == "FSDP": f.write("compute_environment: LOCAL_MACHINE\n")
_write_fsdp_config(f, num_processes) f.write("distributed_type: MULTI_GPU\n")
else: f.write("mixed_precision: 'no'\n")
_write_multi_gpu_config(f, num_processes) f.write(f"num_processes: {num_processes}\n")
f.write("use_cpu: false\n")
f.write("gpu_ids: all\n")
f.write("downcast_bf16: 'no'\n")
f.write("machine_rank: 0\n")
f.write("main_training_function: main\n")
f.write("num_machines: 1\n")
f.write("rdzv_backend: static\n")
f.write("same_network: true\n")
cmd = [ cmd = [
"accelerate", "accelerate",
@@ -166,7 +134,7 @@ class TestMultiGPUTraining:
f"--output_dir={output_dir}", f"--output_dir={output_dir}",
"--batch_size=4", "--batch_size=4",
"--steps=10", "--steps=10",
"--eval_freq=-1", "--env_eval_freq=-1",
"--log_freq=5", "--log_freq=5",
"--save_freq=10", "--save_freq=10",
"--seed=42", "--seed=42",
@@ -209,7 +177,7 @@ class TestMultiGPUTraining:
f"--output_dir={output_dir}", f"--output_dir={output_dir}",
"--batch_size=4", "--batch_size=4",
"--steps=20", "--steps=20",
"--eval_freq=-1", "--env_eval_freq=-1",
"--log_freq=5", "--log_freq=5",
"--save_freq=10", "--save_freq=10",
"--seed=42", "--seed=42",
@@ -243,66 +211,3 @@ class TestMultiGPUTraining:
# Verify optimizer state exists # Verify optimizer state exists
optimizer_state = training_state_dir / "optimizer_state.safetensors" optimizer_state = training_state_dir / "optimizer_state.safetensors"
assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}" assert optimizer_state.exists(), f"No optimizer state in checkpoint {checkpoint_dir}"
def test_fsdp_optimizer_save_and_resume(self):
"""
Test that FSDP saves the (gathered) optimizer state and can resume from it.
Trains a few steps under FSDP, verifies the gathered optimizer state is written next to the
rest of the training state, then resumes from the checkpoint for more steps and checks it
completes without shape/key errors in the FSDP optimizer load path.
"""
# Pre-download dataset to avoid race conditions
download_dataset("lerobot/pusht", episodes=[0])
with tempfile.TemporaryDirectory() as temp_dir:
output_dir = Path(temp_dir) / "outputs"
config_args = [
"--dataset.repo_id=lerobot/pusht",
"--dataset.episodes=[0]",
"--policy.type=act",
"--policy.device=cuda",
"--policy.push_to_hub=false",
f"--output_dir={output_dir}",
"--batch_size=4",
"--steps=10",
"--eval_freq=-1",
"--log_freq=5",
"--save_freq=10",
"--seed=42",
"--num_workers=0",
]
result = run_accelerate_training(
config_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
)
assert result.returncode == 0, (
f"FSDP training failed:\nSTDOUT:\n{result.stdout}\n\nSTDERR:\n{result.stderr}"
)
# The gathered optimizer state must be written under FSDP (proves the save collective ran),
# in the same safetensors format as single-GPU training.
training_state_dir = output_dir / "checkpoints" / "last" / "training_state"
optimizer_state = training_state_dir / "optimizer_state.safetensors"
optimizer_param_groups = training_state_dir / "optimizer_param_groups.json"
assert optimizer_state.exists(), f"FSDP optimizer state not saved in {training_state_dir}"
assert optimizer_param_groups.exists(), (
f"FSDP optimizer param groups not saved in {training_state_dir}"
)
# Resume from the checkpoint for more steps. A successful run proves load_fsdp_optimizer
# accepts the saved state and reshards it without shape/key errors.
resume_config = output_dir / "checkpoints" / "last" / "pretrained_model" / "train_config.json"
resume_args = [
f"--config_path={resume_config}",
"--resume=true",
"--steps=20",
]
resume_result = run_accelerate_training(
resume_args, num_processes=2, temp_dir=temp_dir, distributed_type="FSDP"
)
assert resume_result.returncode == 0, (
f"FSDP resume failed:\nSTDOUT:\n{resume_result.stdout}\n\nSTDERR:\n{resume_result.stderr}"
)
assert "End of training" in resume_result.stdout or "End of training" in resume_result.stderr
-15
View File
@@ -136,18 +136,3 @@ def test_save_load_training_state(tmp_path, optimizer, scheduler):
assert loaded_step == 10 assert loaded_step == 10
assert loaded_optimizer is optimizer assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler assert loaded_scheduler is scheduler
def test_load_training_state_skip_optimizer(tmp_path, optimizer, scheduler):
# FSDP loads optimizer separately (after accelerator.prepare)
# load_training_state(load_optimizer=False) must restore step + scheduler but leave the
# optimizer untouched and never touch the on-disk optimizer state.
save_training_state(tmp_path, 10, optimizer, scheduler)
with patch("lerobot.common.train_utils.load_optimizer_state") as mock_load_optimizer_state:
loaded_step, loaded_optimizer, loaded_scheduler = load_training_state(
tmp_path, optimizer, scheduler, load_optimizer=False
)
mock_load_optimizer_state.assert_not_called()
assert loaded_step == 10
assert loaded_optimizer is optimizer
assert loaded_scheduler is scheduler