mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
train: switch EMA from custom ModelEMA to ema-pytorch
Replace the 250-line src/lerobot/utils/ema.py with a direct dependency
on ema-pytorch (lucidrains' canonical PyTorch EMA library). Same
semantics, decay=0.999 default unchanged, but offloads the maintenance
burden to a maintained library used by every diffusion repo.
Why ema-pytorch:
* Standard PyTorch EMA library; battle-tested across diffusion +
speech + image-gen codebases.
* Tiny pure-python dep (no compiled code).
* Cleaner consumer-side API: ema.ema_model is a full nn.Module
clone of the policy, so eval / wandb just pass it through instead
of context-managed swap/restore on the live model.
What changed mechanically:
* pyproject.toml: add 'ema-pytorch>=0.7.7,<1.0.0' to core deps.
* deleted src/lerobot/utils/ema.py (the custom ModelEMA).
* scripts/lerobot_train.py:
- import EMA from ema_pytorch
- instantiate with beta=cfg.ema.decay,
update_after_step=cfg.ema.warmup_steps, update_every=1,
include_online_model=False (accelerator owns live model
lifecycle; double-registration would double-count params).
- ema.update() (no args) — library tracks the online model
internally.
- Eval block: pass eval_target_policy = ema.ema_model (when
cfg.ema.use_for_eval) instead of swap context manager.
- W&B examples: same pattern.
- Save: torch.save(ema.state_dict(), .../ema_state.pt) instead
of custom safetensors writer. .pt format is consistent with
the rest of training_state which already mixes safetensors +
json + (now) pt.
- Resume: ema.load_state_dict(torch.load(.../ema_state.pt)).
- WandB observability: ema/step (count of ema.update calls),
ema/initted (bool from library), ema/beta (constant from
cfg).
* configs/default.py: EMAConfig.decay stays 0.999 (matches
openpi's pi05_libero); docstring updated to reflect ema-pytrch
semantics for warmup_steps (now maps to update_after_step — a hard
skip, not a smooth decay ramp).
Behavior preserved:
* Defaults: enable=False, decay=0.999, warmup_steps=0,
use_for_eval=True, use_for_wandb_examples=True.
* Same CLI: --ema.enable=true, --ema.decay=X, etc.
* Same checkpoint layout (training_state/ema_state.pt next to
optimizer_state.safetensors etc.); resumes silently if present.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -85,6 +85,11 @@ dependencies = [
|
||||
"termcolor>=2.4.0,<4.0.0",
|
||||
"tqdm>=4.66.0,<5.0.0",
|
||||
|
||||
# Training utilities
|
||||
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
|
||||
# pure-python dependency — preferred over a hand-rolled implementation.
|
||||
"ema-pytorch>=0.7.7,<1.0.0",
|
||||
|
||||
# Build tools (required by opencv-python-headless on some platforms)
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
|
||||
Reference in New Issue
Block a user