ema: enable by default (matches openpi JAX behavior)

Flip EMAConfig.enable default from False -> True. Every training run
now maintains an EMA shadow of the policy and uses it for eval + W&B
example dumps. Disable per-run with --ema.enable=false for short or
memory-constrained training.
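
As a rough illustration of what "maintains an EMA shadow" means, the
sketch below applies the update rule quoted in the EMAConfig comment
(θ_ema ← β·θ_ema + (1-β)·θ_live) to plain Python lists. The names
`ema_update` and `effective_horizon` are hypothetical and are not
LeRobot or ema-pytorch APIs; the real update runs on tensors and may add
warmup behavior that is not modeled here.

```python
def ema_update(shadow, live, beta):
    """One EMA step: beta * shadow + (1 - beta) * live, elementwise."""
    return [beta * s + (1.0 - beta) * w for s, w in zip(shadow, live)]


def effective_horizon(beta):
    """Approximate number of recent steps the average covers: 1 / (1 - beta)."""
    return 1.0 / (1.0 - beta)


# Toy example: the shadow starts at zero and drifts toward the live weights.
shadow = [0.0, 0.0]
live = [1.0, 2.0]
for _ in range(3):
    shadow = ema_update(shadow, live, beta=0.5)
```

With beta=0.999 the horizon works out to about 1000 steps, and with
beta=0.99 to about 100 steps.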

Rationale:
  * openpi (JAX, official) enables EMA in every shipped config:
    decay=0.99 by default and 0.999 for pi05_libero. The openpi
    PyTorch port explicitly lists EMA as unsupported, a gap LeRobot
    main inherited. Flipping the default closes that gap for every
    LeRobot policy trained through lerobot-train.
  * EMA is established best practice for diffusion / flow-matching
    policies (Diffusion Policy §V.D; standard in DDPM/EDM/Stable
    Diffusion training recipes). For autoregressive policies the
    extra cost is real but the safety net (smoother eval, better
    final checkpoint) doesn't hurt.

Trade-offs to be aware of:
  * Memory: 1x model params in fp32 shadow (~13 GB for pi052's
    3.3B params; <500 MB for ACT/Diffusion-Policy class). Memory-
    constrained users on consumer GPUs may need --ema.enable=false.
  * Checkpoint disk: extra .pt file in training_state/, size ~=
    pretrained_model/model.safetensors. Over a 100k-step run with
    save_freq=20000 that's 5x the model size in extra disk.
  * Eval scores now reflect the EMA model instead of the live model.
    They are expected to be 1-3% higher on closed-loop tasks per the
    diffusion-policy literature, which may surprise users comparing
    against numbers from earlier runs.
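
The memory and disk figures above can be checked with a few lines of
arithmetic. This is only a back-of-the-envelope sketch: the helper names
are hypothetical, GB is taken as 10^9 bytes, and it assumes one EMA file
is written at every multiple of save_freq and that no checkpoint is
pruned.

```python
def ema_shadow_gb(num_params, bytes_per_param=4):
    """Size of an fp32 shadow copy in GB (10^9 bytes)."""
    return num_params * bytes_per_param / 1e9


def extra_checkpoint_copies(total_steps, save_freq):
    """Number of EMA files written if every checkpoint is kept."""
    return total_steps // save_freq


pi052_shadow_gb = ema_shadow_gb(3.3e9)                  # 3.3B params in fp32
copies = extra_checkpoint_copies(100_000, 20_000)       # 100k steps, save_freq=20000
```

Under these assumptions the shadow is about 13.2 GB and the run writes
five extra model-sized files.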

Opt out:
  --ema.enable=false           # disable entirely
  --ema.use_for_eval=false     # keep EMA but eval reflects live
  --ema.use_for_wandb_examples=false   # keep EMA but W&B reflects live
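
How the three flags interact can be sketched as follows. The field
names mirror the CLI flags above, but the `EMAFlags` dataclass, the
`weights_for` helper, and the assumption that both `use_for_*` fields
default to True are illustrative only and are not the actual EMAConfig
implementation.

```python
from dataclasses import dataclass


@dataclass
class EMAFlags:
    # Hypothetical mirror of the three flags; defaults are assumed.
    enable: bool = True
    use_for_eval: bool = True
    use_for_wandb_examples: bool = True


def weights_for(purpose, flags):
    """Return "ema" or "live" for purpose in {"eval", "wandb_examples"}."""
    if purpose not in ("eval", "wandb_examples"):
        raise ValueError(f"unknown purpose: {purpose!r}")
    if not flags.enable:
        # With EMA disabled there is no shadow, so everything uses live weights.
        return "live"
    use_ema = flags.use_for_eval if purpose == "eval" else flags.use_for_wandb_examples
    return "ema" if use_ema else "live"
```

For example, `weights_for("eval", EMAFlags(use_for_eval=False))` selects
the live weights while W&B example dumps would still use the shadow.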

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-05-25 21:58:46 +02:00
parent 72ea531017
commit 2ed6519a93
+6 -4
@@ -98,12 +98,14 @@ class EMAConfig:
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
params) + one elementwise update per training step (~1% step time).
-Off by default (back-compat). Recommended for long pi052 training
-runs — typically ~13% absolute success-rate improvement on
-closed-loop tasks per the diffusion-policy literature.
+On by default — matches openpi (JAX) which ships EMA on for every
+config, and closes the gap with the openpi PyTorch port which
+explicitly lists EMA as unsupported. Set ``--ema.enable=false`` to
+disable for short runs / memory-constrained training where the
+extra fp32 shadow copy is the bottleneck.
"""
-enable: bool = False
+enable: bool = True
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
# ema-pytorch as ``beta``).
# 0.999 — last ~1000 steps; pi05_libero default in openpi