train: switch EMA from custom ModelEMA to ema-pytorch

Replace the 250-line src/lerobot/utils/ema.py with a direct dependency
on ema-pytorch (lucidrains' PyTorch EMA library). Same update rule and
decay=0.999 default; warmup_steps changes meaning (see below). This
moves the maintenance burden to an actively maintained library that is
widely used in diffusion codebases.

Why ema-pytorch:
  * Standard PyTorch EMA library; battle-tested across diffusion +
    speech + image-gen codebases.
  * Tiny pure-python dep (no compiled code).
  * Cleaner consumer-side API: ema.ema_model is a full nn.Module
    clone of the policy, so eval / wandb just pass it through instead
    of context-managed swap/restore on the live model.
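
The difference between the two consumer-side patterns can be sketched
with a toy stand-in for a model. This is an illustration only:
ToyModel, swap_weights and evaluate are hypothetical names, a plain
dict stands in for the parameters, and nothing here calls ema-pytorch
or the removed ModelEMA.

```python
from contextlib import contextmanager
from copy import deepcopy


class ToyModel:
    """Hypothetical stand-in for a policy: parameters are a plain dict."""

    def __init__(self, weights):
        self.weights = dict(weights)


@contextmanager
def swap_weights(model, shadow):
    """Old pattern: overwrite the live weights, restore them on exit."""
    backup = dict(model.weights)
    model.weights.update(shadow)
    try:
        yield model
    finally:
        model.weights = backup


def evaluate(model):
    """Hypothetical eval: report the weights the model currently holds."""
    return dict(model.weights)


live = ToyModel({"w": 1.0})
shadow = {"w": 0.5}

# Old pattern: the live model is mutated for the duration of the block.
with swap_weights(live, shadow) as swapped:
    seen_with_swap = evaluate(swapped)

# New pattern: a separate clone carries the averaged weights, so the
# live model is never touched and there is no restore step to get wrong.
ema_model = deepcopy(live)
ema_model.weights.update(shadow)
seen_with_clone = evaluate(ema_model)
```

Both paths evaluate the same averaged weights; only the second leaves
the live model untouched throughout.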

What changed mechanically:
  * pyproject.toml: add 'ema-pytorch>=0.7.7,<1.0.0' to core deps.
  * deleted src/lerobot/utils/ema.py (the custom ModelEMA).
  * scripts/lerobot_train.py:
      - import EMA from ema_pytorch
      - instantiate with beta=cfg.ema.decay,
        update_after_step=cfg.ema.warmup_steps, update_every=1,
        include_online_model=False (accelerator owns live model
        lifecycle; double-registration would double-count params).
      - ema.update() (no args) — library tracks the online model
        internally.
      - Eval block: pass eval_target_policy = ema.ema_model (when
        cfg.ema.use_for_eval) instead of swap context manager.
      - W&B examples: same pattern.
      - Save: torch.save(ema.state_dict(), .../ema_state.pt) instead
        of custom safetensors writer. .pt format is consistent with
        the rest of training_state which already mixes safetensors +
        json + (now) pt.
      - Resume: ema.load_state_dict(torch.load(.../ema_state.pt)).
      - W&B observability: log ema/step (count of ema.update calls),
        ema/initted (bool from library, logged as a float), and
        ema/beta (constant from cfg). These replace
        ema/effective_decay and ema/num_updates.
  * configs/default.py: EMAConfig.decay stays 0.999 (matches
    openpi's pi05_libero); docstring updated to reflect ema-pytorch
    semantics for warmup_steps (now maps to update_after_step: a hard
    skip, not a smooth decay ramp).
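
To make the warmup_steps change concrete, here is a rough sketch
comparing the two schedules. ramp_decay follows the formula in the
removed ModelEMA; hard_skip_decay is a simplified model of the new
behaviour as this commit describes it, not a copy of ema-pytorch's
code. Both function names are made up for illustration. The exact
step at which the library starts averaging may differ by one, and
ema-pytorch also exposes inv_gamma and power arguments that shape its
own decay schedule; those are left out here.

```python
def ramp_decay(n, decay, warmup_steps):
    """Effective decay of the removed ModelEMA at update count n (1-based)."""
    if warmup_steps <= 0 or n >= warmup_steps:
        return decay
    return min(decay, (1.0 + n) / (10.0 + n))


def hard_skip_decay(n, decay, update_after_step):
    """Simplified model of the new behaviour (an assumption, see lead-in).

    A decay of 0.0 means the shadow is a plain copy of the live weights;
    averaging with the target decay starts once the skip window is over.
    """
    return 0.0 if n <= update_after_step else decay


# (update count, old ramp, new hard skip) for decay=0.999, warmup_steps=100
rows = [
    (n, ramp_decay(n, 0.999, 100), hard_skip_decay(n, 0.999, 100))
    for n in (1, 50, 100, 101)
]
```

With the old ramp the shadow already averages during warmup (decay
0.85 at update 50); with the hard skip it only mirrors the live
weights until the window ends.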

Behavior preserved:
  * Defaults: enable=False, decay=0.999, warmup_steps=0,
    use_for_eval=True, use_for_wandb_examples=True.
  * Same CLI: --ema.enable=true, --ema.decay=X, etc.
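  * Same checkpoint directory (training_state/, next to
    optimizer_state.safetensors etc.), but the EMA file is now
    ema_state.pt instead of ema_state.safetensors. It is loaded on
    resume if present; an ema_state.safetensors written by the old
    code is not read, so EMA restarts from the live weights.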
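
The resume lookup can be sketched as a small path check. The helper
name find_ema_state is hypothetical; the path it builds
(training_state/ema_state.pt under the checkpoint directory) is the
one used in the diff. Empty placeholder files stand in for real state
files, so this only illustrates which file name is picked up.

```python
from pathlib import Path
import tempfile


def find_ema_state(checkpoint_path):
    """Return the EMA state file inside a checkpoint, or None if absent."""
    if checkpoint_path is None:
        return None
    candidate = Path(checkpoint_path) / "training_state" / "ema_state.pt"
    return candidate if candidate.exists() else None


with tempfile.TemporaryDirectory() as tmp:
    ckpt = Path(tmp)
    missing = find_ema_state(ckpt)  # nothing written yet

    state_dir = ckpt / "training_state"
    state_dir.mkdir(parents=True)
    (state_dir / "ema_state.safetensors").touch()  # old-format file only
    old_format_only = find_ema_state(ckpt)

    (state_dir / "ema_state.pt").touch()  # new-format file
    found_name = find_ema_state(ckpt).name
```

A checkpoint that holds only the old ema_state.safetensors is treated
the same as one with no EMA state at all.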

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Pepijn
2026-05-25 21:51:23 +02:00
parent 56a934ec55
commit 72ea531017
4 changed files with 103 additions and 329 deletions
pyproject.toml: +5
@@ -85,6 +85,11 @@ dependencies = [
"termcolor>=2.4.0,<4.0.0",
"tqdm>=4.66.0,<5.0.0",
# Training utilities
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
# pure-python dependency — preferred over a hand-rolled implementation.
"ema-pytorch>=0.7.7,<1.0.0",
# Build tools (required by opencv-python-headless on some platforms)
"cmake>=3.29.0.1,<4.2.0",
"setuptools>=71.0.0,<81.0.0",
configs/default.py: +31 -22
@@ -87,35 +87,44 @@ class EMAConfig:
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
pi052) benefit substantially from averaging late-training
parameter oscillations — see Chi et al. 2023 §V.D. EMA adds a
fp32 shadow of the policy (~13 GB for pi052's 3.3B params) and one
elementwise update per training step (~1% step time).
parameter oscillations — see Chi et al. 2023 §V.D. The official
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
``0.999`` for its pi05_libero config; the openpi PyTorch port
explicitly lists EMA as unsupported, and LeRobot main inherited
that gap. Enabling this flag plugs ema-pytorch
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
training loop with a shadow ``nn.Module`` clone of the policy.
Off by default (back-compat for existing runs). Recommended for
long pi052 training runs where you want closed-loop eval to use
the smoothed weights — typically ~13% absolute success-rate
improvement on closed-loop tasks per the diffusion-policy lit.
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
params) + one elementwise update per training step (~1% step time).
Off by default (back-compat). Recommended for long pi052 training
runs — typically ~13% absolute success-rate improvement on
closed-loop tasks per the diffusion-policy literature.
"""
enable: bool = False
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live.
# 0.999 ≈ last 1000 steps (standard; pi05-class default)
# 0.75 ≈ very fast EMA (Diffusion Policy original setting)
# 0.9999 ≈ very slow EMA (long classification runs)
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
# ema-pytorch as ``beta``).
# 0.999 — last ~1000 steps; pi05_libero default in openpi
# 0.99 — last ~100 steps; openpi top-level default
# 0.75 — very fast EMA (Diffusion Policy original setting)
# 0.9999 — very slow EMA (long classification runs)
decay: float = 0.999
# If > 0, ramp effective decay up to ``decay`` over the first N
# updates as min(decay, (1+n)/(10+n)). Lets the EMA track rapid
# early-training changes before settling. ``0`` = use ``decay``
# from step 1.
# Skip the first N calls to ``ema.update()``; during this window
# the shadow is just a hard copy of the live weights (no averaging).
# Lets early-training rapid changes settle before averaging begins.
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
# ramp like older lerobot EMA implementations).
warmup_steps: int = 0
# When True, the periodic eval block uses EMA weights (via a
# context-managed swap that restores the live weights on exit).
# Standard practice for diffusion-style policies — eval scores
# are usually 13% higher than the live policy at the same step.
# When True, the periodic eval block uses the EMA shadow model
# directly (``ema.ema_model``) instead of the live policy. Standard
# practice for diffusion-style policies — eval scores are usually
# 13% higher than the live policy at the same step.
use_for_eval: bool = True
# When True, the periodic wandb training-example dump uses EMA
# weights for the optional predicted-action columns (so what you
# see in W&B matches eval behavior).
# When True, the periodic wandb training-example dump uses the EMA
# shadow for the optional predicted-action columns (so what you see
# in W&B matches eval behavior).
use_for_wandb_examples: bool = True
scripts/lerobot_train.py: +67 -57
@@ -50,7 +50,6 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.rewards import make_reward_pre_post_processors
from lerobot.utils.collate import lerobot_collate_fn
from lerobot.utils.ema import ModelEMA
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
@@ -658,21 +657,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# ------------------------------------------------------------------
# EMA setup
# ------------------------------------------------------------------
# Maintain a shadow copy of the trainable params for late-training
# averaging (Chi et al. 2023 Diffusion Policy §V.D; standard for
# diffusion/flow-matching policies). Lives only on the main process —
# accelerator broadcasts updates to other ranks naturally via the
# live model. Off by default; opt in with ``--ema.enable=true``.
ema: ModelEMA | None = None
# Shadow copy of the trainable params for late-training averaging
# (Chi et al. 2023 Diffusion Policy §V.D; openpi JAX trainer ships
# this with decay=0.999 for pi05_libero; openpi PyTorch port and
# LeRobot main both skip it). Off by default; opt in with
# ``--ema.enable=true``. Implemented via ema-pytorch
# (https://github.com/lucidrains/ema-pytorch) — the standard PyTorch
# EMA library, also used by lucidrains' diffusion repos.
ema = None
if cfg.ema.enable and is_main_process:
ema = ModelEMA(
from ema_pytorch import EMA # noqa: PLC0415
ema = EMA(
accelerator.unwrap_model(policy),
decay=cfg.ema.decay,
warmup_steps=cfg.ema.warmup_steps,
beta=cfg.ema.decay,
update_after_step=cfg.ema.warmup_steps,
update_every=1, # update on every ema.update() call
# Don't register the live model as an ema submodule — accelerator
# already owns its lifecycle, and double-registration would
# double-count its params in ``ema.state_dict()``.
include_online_model=False,
)
ema.to(accelerator.device)
logging.info(
"EMA enabled: decay=%g, warmup_steps=%d, use_for_eval=%s, "
"use_for_wandb_examples=%s",
"EMA enabled (ema-pytorch): beta=%g, update_after_step=%d, "
"use_for_eval=%s, use_for_wandb_examples=%s",
cfg.ema.decay,
cfg.ema.warmup_steps,
cfg.ema.use_for_eval,
@@ -681,16 +690,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# Resume the EMA shadow if a previous run wrote one.
if cfg.checkpoint_path is not None:
ema_path = (
cfg.checkpoint_path / "training_state" / "ema_state.safetensors"
)
ema_path = cfg.checkpoint_path / "training_state" / "ema_state.pt"
if ema_path.exists():
logging.info("Resuming EMA shadow from %s", ema_path)
try:
ema = ModelEMA.load_from_file(
accelerator.unwrap_model(policy),
ema_path,
)
ema.load_state_dict(torch.load(ema_path, map_location=accelerator.device))
except Exception as exc: # noqa: BLE001
logging.warning(
"Failed to load EMA shadow (%s) — restarting EMA from "
@@ -753,11 +757,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# EMA update: pull one step of the live weights into the shadow.
# Runs only on the main process (the shadow lives there); other
# ranks rely on the live model staying in sync via accelerator.
# Returns the effective decay used (interesting during warmup).
# ``ema-pytorch`` holds an internal reference to the online model
# (set at construction), so ``ema.update()`` takes no args.
if ema is not None:
ema_effective_decay = ema.update(accelerator.unwrap_model(policy))
else:
ema_effective_decay = None
ema.update()
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
@@ -800,12 +803,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if sample_weighter is not None:
weighter_stats = sample_weighter.get_stats()
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
# EMA observability: the effective decay differs from
# cfg.ema.decay during warmup; ``num_updates`` lets you
# confirm the EMA is actually firing.
# EMA observability: ``ema.step`` is the count of
# ``ema.update()`` calls (= optimizer steps once EMA is
# enabled); ``ema.initted`` flips to True once we've
# crossed ``update_after_step``.
if ema is not None:
wandb_log_dict["ema/effective_decay"] = float(ema_effective_decay or ema.decay)
wandb_log_dict["ema/num_updates"] = int(ema.num_updates)
wandb_log_dict["ema/step"] = int(ema.step.item())
wandb_log_dict["ema/initted"] = float(ema.initted.item())
wandb_log_dict["ema/beta"] = float(cfg.ema.decay)
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
@@ -820,23 +825,24 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
and is_main_process
):
try:
# Optionally show the EMA model's predictions in the W&B
# examples table (matches the policy you'd actually
# deploy / eval with). Live weights are restored on exit.
wandb_ema_ctx = (
ema.apply_to(accelerator.unwrap_model(policy))
if ema is not None and cfg.ema.use_for_wandb_examples
else nullcontext()
# Optionally use the EMA shadow model directly for the
# predicted-action columns (matches what eval / deployment
# would see). ``ema-pytorch`` exposes the shadow as a
# full ``nn.Module`` at ``ema.ema_model``, so we just
# pass that instead of swap-and-restore.
target_policy = (
ema.ema_model
if (ema is not None and cfg.ema.use_for_wandb_examples)
else accelerator.unwrap_model(policy)
)
wandb_logger.log_training_examples(
batch=batch,
step=step,
camera_keys=list(dataset.meta.camera_keys),
n_samples=cfg.wandb.log_examples_n,
policy=target_policy,
predict_actions=cfg.wandb.log_examples_predict_actions,
)
with wandb_ema_ctx:
wandb_logger.log_training_examples(
batch=batch,
step=step,
camera_keys=list(dataset.meta.camera_keys),
n_samples=cfg.wandb.log_examples_n,
policy=accelerator.unwrap_model(policy),
predict_actions=cfg.wandb.log_examples_predict_actions,
)
except Exception as exc: # noqa: BLE001
logging.warning("wandb log_training_examples failed: %s", exc)
@@ -857,11 +863,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
update_last_checkpoint(checkpoint_dir)
# Save the EMA shadow alongside the training state so a
# resumed run picks up exactly where the live EMA left off.
# ``ema-pytorch.state_dict()`` returns the full shadow
# nn.Module's state dict + step/initted buffers; saved as
# .pt (the rest of training_state mixes formats already).
if ema is not None:
try:
ema.save(
checkpoint_dir / "training_state" / "ema_state.safetensors"
)
ema_path = checkpoint_dir / "training_state" / "ema_state.pt"
ema_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(ema.state_dict(), ema_path)
except Exception as exc: # noqa: BLE001
logging.warning("Failed to save EMA shadow: %s", exc)
if wandb_logger:
@@ -873,19 +882,20 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
if is_main_process:
step_id = get_step_identifier(step, cfg.steps)
logging.info(f"Eval policy at step {step}")
# Optionally swap in EMA weights for eval — standard
# practice for diffusion-style policies (~13% lift on
# closed-loop success). Live weights are restored byte-
# for-byte on context exit.
eval_ema_ctx = (
ema.apply_to(accelerator.unwrap_model(policy))
if ema is not None and cfg.ema.use_for_eval
else nullcontext()
# Use the EMA shadow model for eval when enabled —
# standard practice for diffusion-style policies (~13%
# lift on closed-loop success). ``ema.ema_model`` is a
# full nn.Module clone, so we just pass it through; no
# swap/restore on the live policy needed.
eval_target_policy = (
ema.ema_model
if (ema is not None and cfg.ema.use_for_eval)
else accelerator.unwrap_model(policy)
)
with eval_ema_ctx, torch.no_grad(), accelerator.autocast():
with torch.no_grad(), accelerator.autocast():
eval_info = eval_policy_all(
envs=eval_env, # dict[suite][task_id] -> vec_env
policy=accelerator.unwrap_model(policy),
policy=eval_target_policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
src/lerobot/utils/ema.py: -250
@@ -1,250 +0,0 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Exponential Moving Average of model parameters for training stability.
Maintains a shadow copy of every trainable parameter, updated after each
optimizer step::
θ_ema ← β · θ_ema + (1 - β) · θ_live
At eval / inference / final checkpoint, use ``θ_ema`` instead of
``θ_live``. For diffusion / flow-matching policies, averaging late-
training oscillations yields a smoother model that generalises
substantially better at inference — see Chi et al. 2023 (Diffusion
Policy §V.D, β=0.75), Ho et al. 2020 (DDPM appendix). For VLAs with a
flow-matching action expert the same logic applies: flow gradients have
high variance per sample (different noise levels in the same batch),
so EMA smooths over that variance.
Cost: 1× model parameters in fp32 (~13 GB for pi052's 3.3B params),
plus one elementwise update per training step (~1% of step time).
"""
from __future__ import annotations
import json
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Iterator
import torch
import torch.nn as nn
from safetensors.torch import load_file, save_file
class ModelEMA:
"""Exponential moving average of trainable model parameters.
Args:
model: The live model whose parameter shapes/names define the
shadow. Only parameters with ``requires_grad=True`` are
tracked. Buffers are intentionally NOT tracked (LayerNorm
running stats, RoPE caches, etc.) — they are updated in
``train()`` mode regardless of which weights we apply.
decay: Target EMA decay (``β`` in
``θ_ema ← β·θ_ema + (1-β)·θ_live``). Typical values:
* ``0.999`` — averages roughly the last 1000 steps. Standard
for diffusion-style policies and the default here.
* ``0.75`` — very fast EMA, used by Diffusion Policy (Chi
et al. 2023). Useful when training is short or noisy.
* ``0.9999`` — very slow EMA, used in image classification
for very long runs.
warmup_steps: If > 0, ramp the effective decay from a low value
up to ``decay`` over the first ``warmup_steps`` updates as
``min(decay, (1 + n) / (10 + n))``. Lets the EMA track
rapid early-training changes before settling on the target.
device: Where to keep the shadow parameters. ``None`` keeps each
shadow on its parameter's device (good for FSDP / multi-GPU).
Pass an explicit device to relocate everything (e.g. ``"cpu"``
to free GPU memory at the cost of slower updates).
"""
def __init__(
self,
model: nn.Module,
*,
decay: float = 0.999,
warmup_steps: int = 0,
device: torch.device | str | None = None,
) -> None:
if not 0.0 < decay < 1.0:
raise ValueError(f"decay must be in (0, 1), got {decay}")
self.decay = float(decay)
self.warmup_steps = int(warmup_steps)
self.num_updates = 0
self.device = torch.device(device) if device is not None else None
# fp32 shadow — small EMA updates lose precision in bf16.
self.shadow: dict[str, torch.Tensor] = {}
for name, p in model.named_parameters():
if not p.requires_grad:
continue
shadow_p = p.detach().clone().float()
if self.device is not None:
shadow_p = shadow_p.to(self.device)
self.shadow[name] = shadow_p
# ------------------------------------------------------------------
# Core update
# ------------------------------------------------------------------
def _effective_decay(self) -> float:
if self.warmup_steps <= 0 or self.num_updates >= self.warmup_steps:
return self.decay
# Standard EMA warmup (timm / diffusers convention): grows
# 0.09, 0.16, 0.23, ... and saturates at ``decay``.
return min(self.decay, (1.0 + self.num_updates) / (10.0 + self.num_updates))
@torch.no_grad()
def update(self, model: nn.Module) -> float:
"""Pull one update from the live model into the shadow.
Returns the effective decay used this step (useful to log during
warmup, when the value differs from ``self.decay``).
"""
self.num_updates += 1
beta = self._effective_decay()
one_minus_beta = 1.0 - beta
for name, p in model.named_parameters():
if not p.requires_grad:
continue
shadow = self.shadow.get(name)
if shadow is None:
# New parameter appeared mid-training — seed it.
shadow = p.detach().clone().float()
if self.device is not None:
shadow = shadow.to(self.device)
self.shadow[name] = shadow
continue
# In-place fused: shadow ← β · shadow + (1 - β) · p
shadow.mul_(beta).add_(
p.detach().to(shadow.device, dtype=torch.float32),
alpha=one_minus_beta,
)
return beta
# ------------------------------------------------------------------
# Applying the EMA to the live model
# ------------------------------------------------------------------
@torch.no_grad()
def copy_to(self, model: nn.Module) -> None:
"""Overwrite the live model's parameters with the EMA shadow.
In-place and **irreversible** — the previous live weights are
lost. Use this only at the very end of training when you want
the EMA to *be* the final saved policy. For temporary swaps
(e.g. during eval), use :meth:`apply_to`.
"""
for name, p in model.named_parameters():
shadow = self.shadow.get(name)
if shadow is not None:
p.data.copy_(shadow.to(p.device, dtype=p.dtype))
@contextmanager
def apply_to(self, model: nn.Module) -> Iterator[None]:
"""Temporarily swap the live model's weights with the EMA copy.
On exit, the original live weights are restored byte-for-byte
(we keep a backup clone of every tracked parameter inside the
context). Use this around eval / sample-logging without
disturbing the live training state::
with ema.apply_to(policy):
eval_metrics = evaluate(policy)
# policy is back to its pre-eval state here.
"""
backup: dict[str, torch.Tensor] = {}
for name, p in model.named_parameters():
if name in self.shadow:
backup[name] = p.detach().clone()
p.data.copy_(self.shadow[name].to(p.device, dtype=p.dtype))
try:
yield
finally:
for name, p in model.named_parameters():
if name in backup:
p.data.copy_(backup[name].to(p.device, dtype=p.dtype))
# ------------------------------------------------------------------
# Checkpointing
# ------------------------------------------------------------------
def state_dict(self) -> dict[str, Any]:
return {
"shadow": {k: v.detach().cpu() for k, v in self.shadow.items()},
"num_updates": self.num_updates,
"decay": self.decay,
"warmup_steps": self.warmup_steps,
}
def load_state_dict(self, state: dict[str, Any]) -> None:
self.decay = float(state["decay"])
self.warmup_steps = int(state["warmup_steps"])
self.num_updates = int(state["num_updates"])
new_shadow: dict[str, torch.Tensor] = {}
for k, v in state["shadow"].items():
t = v.detach()
if self.device is not None:
t = t.to(self.device)
new_shadow[k] = t.float()
self.shadow = new_shadow
def save(self, path: Path | str) -> None:
"""Save the shadow as safetensors + a tiny JSON sidecar with metadata.
Sidecar lives at ``<path>.json`` and stores ``num_updates``,
``decay``, ``warmup_steps`` — enough to resume exact EMA state.
"""
path = Path(path)
path.parent.mkdir(parents=True, exist_ok=True)
save_file(
{k: v.detach().cpu().contiguous() for k, v in self.shadow.items()},
str(path),
)
meta = {
"num_updates": self.num_updates,
"decay": self.decay,
"warmup_steps": self.warmup_steps,
}
path.with_suffix(path.suffix + ".json").write_text(json.dumps(meta, indent=2))
@classmethod
def load_from_file(
cls,
model: nn.Module,
path: Path | str,
*,
device: torch.device | str | None = None,
) -> "ModelEMA":
"""Reconstruct a ``ModelEMA`` from a previously-saved safetensors + sidecar pair."""
path = Path(path)
meta_path = path.with_suffix(path.suffix + ".json")
meta = json.loads(meta_path.read_text()) if meta_path.exists() else {}
ema = cls(
model,
decay=float(meta.get("decay", 0.999)),
warmup_steps=int(meta.get("warmup_steps", 0)),
device=device,
)
shadow = load_file(str(path))
target_device = ema.device
ema.shadow = {
k: (v.to(target_device) if target_device is not None else v).float()
for k, v in shadow.items()
}
ema.num_updates = int(meta.get("num_updates", 0))
return ema
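
The update rule θ_ema ← β·θ_ema + (1-β)·θ_live quoted in both the config comments and the removed module, together with the "last ~N steps" reading of β, can be checked with a few lines of plain Python. This is a minimal sketch on lists of floats, not on tensors; ema_update and horizon are illustrative names, and 1 / (1 - β) is the usual rule of thumb for the averaging window rather than an exact cutoff.

```python
def ema_update(shadow, live, beta):
    """One step of shadow <- beta * shadow + (1 - beta) * live."""
    return [beta * s + (1.0 - beta) * x for s, x in zip(shadow, live)]


def horizon(beta):
    """Rule-of-thumb averaging window in steps: 1 / (1 - beta)."""
    return 1.0 / (1.0 - beta)


shadow = [0.0, 0.0]
live = [1.0, -2.0]
for _ in range(3):
    shadow = ema_update(shadow, live, beta=0.5)
# After k steps toward a constant target the shadow has closed a
# fraction 1 - beta**k of the gap: here 1 - 0.5**3 = 0.875.

# Windows for the decay values listed in the EMAConfig comments.
horizons = {beta: round(horizon(beta)) for beta in (0.75, 0.99, 0.999, 0.9999)}
```

The resulting windows (4, 100, 1000 and 10000 steps) match the "very fast", "last ~100", "last ~1000" and "very slow" labels in the config comments.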