mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
train: switch EMA from custom ModelEMA to ema-pytorch
Replace the 250-line src/lerobot/utils/ema.py with a direct dependency
on ema-pytorch (lucidrains' canonical PyTorch EMA library). Same
semantics, decay=0.999 default unchanged, but offloads the maintenance
burden to a maintained library used by every diffusion repo.
Why ema-pytorch:
* Standard PyTorch EMA library; battle-tested across diffusion +
speech + image-gen codebases.
* Tiny pure-python dep (no compiled code).
* Cleaner consumer-side API: ema.ema_model is a full nn.Module
clone of the policy, so eval / wandb just pass it through instead
of context-managed swap/restore on the live model.
What changed mechanically:
* pyproject.toml: add 'ema-pytorch>=0.7.7,<1.0.0' to core deps.
* deleted src/lerobot/utils/ema.py (the custom ModelEMA).
* scripts/lerobot_train.py:
- import EMA from ema_pytorch
- instantiate with beta=cfg.ema.decay,
update_after_step=cfg.ema.warmup_steps, update_every=1,
include_online_model=False (accelerator owns live model
lifecycle; double-registration would double-count params).
- ema.update() (no args) — library tracks the online model
internally.
- Eval block: pass eval_target_policy = ema.ema_model (when
cfg.ema.use_for_eval) instead of swap context manager.
- W&B examples: same pattern.
- Save: torch.save(ema.state_dict(), .../ema_state.pt) instead
of custom safetensors writer. .pt format is consistent with
the rest of training_state which already mixes safetensors +
json + (now) pt.
- Resume: ema.load_state_dict(torch.load(.../ema_state.pt)).
- WandB observability: ema/step (count of ema.update calls),
ema/initted (bool from library), ema/beta (constant from
cfg).
* configs/default.py: EMAConfig.decay stays 0.999 (matches
openpi's pi05_libero); docstring updated to reflect ema-pytrch
semantics for warmup_steps (now maps to update_after_step — a hard
skip, not a smooth decay ramp).
Behavior preserved:
* Defaults: enable=False, decay=0.999, warmup_steps=0,
use_for_eval=True, use_for_wandb_examples=True.
* Same CLI: --ema.enable=true, --ema.decay=X, etc.
* Same checkpoint layout (training_state/ema_state.pt next to
optimizer_state.safetensors etc.); resumes silently if present.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -85,6 +85,11 @@ dependencies = [
|
||||
"termcolor>=2.4.0,<4.0.0",
|
||||
"tqdm>=4.66.0,<5.0.0",
|
||||
|
||||
# Training utilities
|
||||
# EMA of policy parameters (Diffusion Policy / pi05 style). Tiny
|
||||
# pure-python dependency — preferred over a hand-rolled implementation.
|
||||
"ema-pytorch>=0.7.7,<1.0.0",
|
||||
|
||||
# Build tools (required by opencv-python-headless on some platforms)
|
||||
"cmake>=3.29.0.1,<4.2.0",
|
||||
"setuptools>=71.0.0,<81.0.0",
|
||||
|
||||
@@ -87,35 +87,44 @@ class EMAConfig:
|
||||
|
||||
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
|
||||
pi052) benefit substantially from averaging late-training
|
||||
parameter oscillations — see Chi et al. 2023 §V.D. EMA adds a
|
||||
fp32 shadow of the policy (~13 GB for pi052's 3.3B params) and one
|
||||
elementwise update per training step (~1% step time).
|
||||
parameter oscillations — see Chi et al. 2023 §V.D. The official
|
||||
JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and
|
||||
``0.999`` for its pi05_libero config; the openpi PyTorch port
|
||||
explicitly lists EMA as unsupported, and LeRobot main inherited
|
||||
that gap. Enabling this flag plugs ema-pytorch
|
||||
(https://github.com/lucidrains/ema-pytorch) into the LeRobot
|
||||
training loop with a shadow ``nn.Module`` clone of the policy.
|
||||
|
||||
Off by default (back-compat for existing runs). Recommended for
|
||||
long pi052 training runs where you want closed-loop eval to use
|
||||
the smoothed weights — typically ~1–3% absolute success-rate
|
||||
improvement on closed-loop tasks per the diffusion-policy lit.
|
||||
Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B
|
||||
params) + one elementwise update per training step (~1% step time).
|
||||
|
||||
Off by default (back-compat). Recommended for long pi052 training
|
||||
runs — typically ~1–3% absolute success-rate improvement on
|
||||
closed-loop tasks per the diffusion-policy literature.
|
||||
"""
|
||||
|
||||
enable: bool = False
|
||||
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live.
|
||||
# 0.999 ≈ last 1000 steps (standard; pi05-class default)
|
||||
# 0.75 ≈ very fast EMA (Diffusion Policy original setting)
|
||||
# 0.9999 ≈ very slow EMA (long classification runs)
|
||||
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to
|
||||
# ema-pytorch as ``beta``).
|
||||
# 0.999 — last ~1000 steps; pi05_libero default in openpi
|
||||
# 0.99 — last ~100 steps; openpi top-level default
|
||||
# 0.75 — very fast EMA (Diffusion Policy original setting)
|
||||
# 0.9999 — very slow EMA (long classification runs)
|
||||
decay: float = 0.999
|
||||
# If > 0, ramp effective decay up to ``decay`` over the first N
|
||||
# updates as min(decay, (1+n)/(10+n)). Lets the EMA track rapid
|
||||
# early-training changes before settling. ``0`` = use ``decay``
|
||||
# from step 1.
|
||||
# Skip the first N calls to ``ema.update()``; during this window
|
||||
# the shadow is just a hard copy of the live weights (no averaging).
|
||||
# Lets early-training rapid changes settle before averaging begins.
|
||||
# Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay
|
||||
# ramp like older lerobot EMA implementations).
|
||||
warmup_steps: int = 0
|
||||
# When True, the periodic eval block uses EMA weights (via a
|
||||
# context-managed swap that restores the live weights on exit).
|
||||
# Standard practice for diffusion-style policies — eval scores
|
||||
# are usually 1–3% higher than the live policy at the same step.
|
||||
# When True, the periodic eval block uses the EMA shadow model
|
||||
# directly (``ema.ema_model``) instead of the live policy. Standard
|
||||
# practice for diffusion-style policies — eval scores are usually
|
||||
# 1–3% higher than the live policy at the same step.
|
||||
use_for_eval: bool = True
|
||||
# When True, the periodic wandb training-example dump uses EMA
|
||||
# weights for the optional predicted-action columns (so what you
|
||||
# see in W&B matches eval behavior).
|
||||
# When True, the periodic wandb training-example dump uses the EMA
|
||||
# shadow for the optional predicted-action columns (so what you see
|
||||
# in W&B matches eval behavior).
|
||||
use_for_wandb_examples: bool = True
|
||||
|
||||
|
||||
|
||||
@@ -50,7 +50,6 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
from lerobot.utils.ema import ModelEMA
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -658,21 +657,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# ------------------------------------------------------------------
|
||||
# EMA setup
|
||||
# ------------------------------------------------------------------
|
||||
# Maintain a shadow copy of the trainable params for late-training
|
||||
# averaging (Chi et al. 2023 Diffusion Policy §V.D; standard for
|
||||
# diffusion/flow-matching policies). Lives only on the main process —
|
||||
# accelerator broadcasts updates to other ranks naturally via the
|
||||
# live model. Off by default; opt in with ``--ema.enable=true``.
|
||||
ema: ModelEMA | None = None
|
||||
# Shadow copy of the trainable params for late-training averaging
|
||||
# (Chi et al. 2023 Diffusion Policy §V.D; openpi JAX trainer ships
|
||||
# this with decay=0.999 for pi05_libero; openpi PyTorch port and
|
||||
# LeRobot main both skip it). Off by default; opt in with
|
||||
# ``--ema.enable=true``. Implemented via ema-pytorch
|
||||
# (https://github.com/lucidrains/ema-pytorch) — the standard PyTorch
|
||||
# EMA library, also used by lucidrains' diffusion repos.
|
||||
ema = None
|
||||
if cfg.ema.enable and is_main_process:
|
||||
ema = ModelEMA(
|
||||
from ema_pytorch import EMA # noqa: PLC0415
|
||||
|
||||
ema = EMA(
|
||||
accelerator.unwrap_model(policy),
|
||||
decay=cfg.ema.decay,
|
||||
warmup_steps=cfg.ema.warmup_steps,
|
||||
beta=cfg.ema.decay,
|
||||
update_after_step=cfg.ema.warmup_steps,
|
||||
update_every=1, # update on every ema.update() call
|
||||
# Don't register the live model as an ema submodule — accelerator
|
||||
# already owns its lifecycle, and double-registration would
|
||||
# double-count its params in ``ema.state_dict()``.
|
||||
include_online_model=False,
|
||||
)
|
||||
ema.to(accelerator.device)
|
||||
logging.info(
|
||||
"EMA enabled: decay=%g, warmup_steps=%d, use_for_eval=%s, "
|
||||
"use_for_wandb_examples=%s",
|
||||
"EMA enabled (ema-pytorch): beta=%g, update_after_step=%d, "
|
||||
"use_for_eval=%s, use_for_wandb_examples=%s",
|
||||
cfg.ema.decay,
|
||||
cfg.ema.warmup_steps,
|
||||
cfg.ema.use_for_eval,
|
||||
@@ -681,16 +690,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
# Resume the EMA shadow if a previous run wrote one.
|
||||
if cfg.checkpoint_path is not None:
|
||||
ema_path = (
|
||||
cfg.checkpoint_path / "training_state" / "ema_state.safetensors"
|
||||
)
|
||||
ema_path = cfg.checkpoint_path / "training_state" / "ema_state.pt"
|
||||
if ema_path.exists():
|
||||
logging.info("Resuming EMA shadow from %s", ema_path)
|
||||
try:
|
||||
ema = ModelEMA.load_from_file(
|
||||
accelerator.unwrap_model(policy),
|
||||
ema_path,
|
||||
)
|
||||
ema.load_state_dict(torch.load(ema_path, map_location=accelerator.device))
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"Failed to load EMA shadow (%s) — restarting EMA from "
|
||||
@@ -753,11 +757,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# EMA update: pull one step of the live weights into the shadow.
|
||||
# Runs only on the main process (the shadow lives there); other
|
||||
# ranks rely on the live model staying in sync via accelerator.
|
||||
# Returns the effective decay used (interesting during warmup).
|
||||
# ``ema-pytorch`` holds an internal reference to the online model
|
||||
# (set at construction), so ``ema.update()`` takes no args.
|
||||
if ema is not None:
|
||||
ema_effective_decay = ema.update(accelerator.unwrap_model(policy))
|
||||
else:
|
||||
ema_effective_decay = None
|
||||
ema.update()
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
@@ -800,12 +803,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
# EMA observability: the effective decay differs from
|
||||
# cfg.ema.decay during warmup; ``num_updates`` lets you
|
||||
# confirm the EMA is actually firing.
|
||||
# EMA observability: ``ema.step`` is the count of
|
||||
# ``ema.update()`` calls (= optimizer steps once EMA is
|
||||
# enabled); ``ema.initted`` flips to True once we've
|
||||
# crossed ``update_after_step``.
|
||||
if ema is not None:
|
||||
wandb_log_dict["ema/effective_decay"] = float(ema_effective_decay or ema.decay)
|
||||
wandb_log_dict["ema/num_updates"] = int(ema.num_updates)
|
||||
wandb_log_dict["ema/step"] = int(ema.step.item())
|
||||
wandb_log_dict["ema/initted"] = float(ema.initted.item())
|
||||
wandb_log_dict["ema/beta"] = float(cfg.ema.decay)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
@@ -820,23 +825,24 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
and is_main_process
|
||||
):
|
||||
try:
|
||||
# Optionally show the EMA model's predictions in the W&B
|
||||
# examples table (matches the policy you'd actually
|
||||
# deploy / eval with). Live weights are restored on exit.
|
||||
wandb_ema_ctx = (
|
||||
ema.apply_to(accelerator.unwrap_model(policy))
|
||||
if ema is not None and cfg.ema.use_for_wandb_examples
|
||||
else nullcontext()
|
||||
# Optionally use the EMA shadow model directly for the
|
||||
# predicted-action columns (matches what eval / deployment
|
||||
# would see). ``ema-pytorch`` exposes the shadow as a
|
||||
# full ``nn.Module`` at ``ema.ema_model``, so we just
|
||||
# pass that instead of swap-and-restore.
|
||||
target_policy = (
|
||||
ema.ema_model
|
||||
if (ema is not None and cfg.ema.use_for_wandb_examples)
|
||||
else accelerator.unwrap_model(policy)
|
||||
)
|
||||
wandb_logger.log_training_examples(
|
||||
batch=batch,
|
||||
step=step,
|
||||
camera_keys=list(dataset.meta.camera_keys),
|
||||
n_samples=cfg.wandb.log_examples_n,
|
||||
policy=target_policy,
|
||||
predict_actions=cfg.wandb.log_examples_predict_actions,
|
||||
)
|
||||
with wandb_ema_ctx:
|
||||
wandb_logger.log_training_examples(
|
||||
batch=batch,
|
||||
step=step,
|
||||
camera_keys=list(dataset.meta.camera_keys),
|
||||
n_samples=cfg.wandb.log_examples_n,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
predict_actions=cfg.wandb.log_examples_predict_actions,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("wandb log_training_examples failed: %s", exc)
|
||||
|
||||
@@ -857,11 +863,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
# Save the EMA shadow alongside the training state so a
|
||||
# resumed run picks up exactly where the live EMA left off.
|
||||
# ``ema-pytorch.state_dict()`` returns the full shadow
|
||||
# nn.Module's state dict + step/initted buffers; saved as
|
||||
# .pt (the rest of training_state mixes formats already).
|
||||
if ema is not None:
|
||||
try:
|
||||
ema.save(
|
||||
checkpoint_dir / "training_state" / "ema_state.safetensors"
|
||||
)
|
||||
ema_path = checkpoint_dir / "training_state" / "ema_state.pt"
|
||||
ema_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
torch.save(ema.state_dict(), ema_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("Failed to save EMA shadow: %s", exc)
|
||||
if wandb_logger:
|
||||
@@ -873,19 +882,20 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
# Optionally swap in EMA weights for eval — standard
|
||||
# practice for diffusion-style policies (~1–3% lift on
|
||||
# closed-loop success). Live weights are restored byte-
|
||||
# for-byte on context exit.
|
||||
eval_ema_ctx = (
|
||||
ema.apply_to(accelerator.unwrap_model(policy))
|
||||
if ema is not None and cfg.ema.use_for_eval
|
||||
else nullcontext()
|
||||
# Use the EMA shadow model for eval when enabled —
|
||||
# standard practice for diffusion-style policies (~1–3%
|
||||
# lift on closed-loop success). ``ema.ema_model`` is a
|
||||
# full nn.Module clone, so we just pass it through; no
|
||||
# swap/restore on the live policy needed.
|
||||
eval_target_policy = (
|
||||
ema.ema_model
|
||||
if (ema is not None and cfg.ema.use_for_eval)
|
||||
else accelerator.unwrap_model(policy)
|
||||
)
|
||||
with eval_ema_ctx, torch.no_grad(), accelerator.autocast():
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
eval_info = eval_policy_all(
|
||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
policy=eval_target_policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Exponential Moving Average of model parameters for training stability.
|
||||
|
||||
Maintains a shadow copy of every trainable parameter, updated after each
|
||||
optimizer step::
|
||||
|
||||
θ_ema ← β · θ_ema + (1 - β) · θ_live
|
||||
|
||||
At eval / inference / final checkpoint, use ``θ_ema`` instead of
|
||||
``θ_live``. For diffusion / flow-matching policies, averaging late-
|
||||
training oscillations yields a smoother model that generalises
|
||||
substantially better at inference — see Chi et al. 2023 (Diffusion
|
||||
Policy §V.D, β=0.75), Ho et al. 2020 (DDPM appendix). For VLAs with a
|
||||
flow-matching action expert the same logic applies: flow gradients have
|
||||
high variance per sample (different noise levels in the same batch),
|
||||
so EMA smooths over that variance.
|
||||
|
||||
Cost: 1× model parameters in fp32 (~13 GB for pi052's 3.3B params),
|
||||
plus one elementwise update per training step (~1% of step time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
class ModelEMA:
|
||||
"""Exponential moving average of trainable model parameters.
|
||||
|
||||
Args:
|
||||
model: The live model whose parameter shapes/names define the
|
||||
shadow. Only parameters with ``requires_grad=True`` are
|
||||
tracked. Buffers are intentionally NOT tracked (LayerNorm
|
||||
running stats, RoPE caches, etc.) — they are updated in
|
||||
``train()`` mode regardless of which weights we apply.
|
||||
decay: Target EMA decay (``β`` in
|
||||
``θ_ema ← β·θ_ema + (1-β)·θ_live``). Typical values:
|
||||
|
||||
* ``0.999`` — averages roughly the last 1000 steps. Standard
|
||||
for diffusion-style policies and the default here.
|
||||
* ``0.75`` — very fast EMA, used by Diffusion Policy (Chi
|
||||
et al. 2023). Useful when training is short or noisy.
|
||||
* ``0.9999`` — very slow EMA, used in image classification
|
||||
for very long runs.
|
||||
warmup_steps: If > 0, ramp the effective decay from a low value
|
||||
up to ``decay`` over the first ``warmup_steps`` updates as
|
||||
``min(decay, (1 + n) / (10 + n))``. Lets the EMA track
|
||||
rapid early-training changes before settling on the target.
|
||||
device: Where to keep the shadow parameters. ``None`` keeps each
|
||||
shadow on its parameter's device (good for FSDP / multi-GPU).
|
||||
Pass an explicit device to relocate everything (e.g. ``"cpu"``
|
||||
to free GPU memory at the cost of slower updates).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
*,
|
||||
decay: float = 0.999,
|
||||
warmup_steps: int = 0,
|
||||
device: torch.device | str | None = None,
|
||||
) -> None:
|
||||
if not 0.0 < decay < 1.0:
|
||||
raise ValueError(f"decay must be in (0, 1), got {decay}")
|
||||
self.decay = float(decay)
|
||||
self.warmup_steps = int(warmup_steps)
|
||||
self.num_updates = 0
|
||||
self.device = torch.device(device) if device is not None else None
|
||||
|
||||
# fp32 shadow — small EMA updates lose precision in bf16.
|
||||
self.shadow: dict[str, torch.Tensor] = {}
|
||||
for name, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
shadow_p = p.detach().clone().float()
|
||||
if self.device is not None:
|
||||
shadow_p = shadow_p.to(self.device)
|
||||
self.shadow[name] = shadow_p
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core update
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _effective_decay(self) -> float:
|
||||
if self.warmup_steps <= 0 or self.num_updates >= self.warmup_steps:
|
||||
return self.decay
|
||||
# Standard EMA warmup (timm / diffusers convention): grows
|
||||
# 0.09, 0.16, 0.23, ... and saturates at ``decay``.
|
||||
return min(self.decay, (1.0 + self.num_updates) / (10.0 + self.num_updates))
|
||||
|
||||
@torch.no_grad()
|
||||
def update(self, model: nn.Module) -> float:
|
||||
"""Pull one update from the live model into the shadow.
|
||||
|
||||
Returns the effective decay used this step (useful to log during
|
||||
warmup, when the value differs from ``self.decay``).
|
||||
"""
|
||||
self.num_updates += 1
|
||||
beta = self._effective_decay()
|
||||
one_minus_beta = 1.0 - beta
|
||||
for name, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
shadow = self.shadow.get(name)
|
||||
if shadow is None:
|
||||
# New parameter appeared mid-training — seed it.
|
||||
shadow = p.detach().clone().float()
|
||||
if self.device is not None:
|
||||
shadow = shadow.to(self.device)
|
||||
self.shadow[name] = shadow
|
||||
continue
|
||||
# In-place fused: shadow ← β · shadow + (1 - β) · p
|
||||
shadow.mul_(beta).add_(
|
||||
p.detach().to(shadow.device, dtype=torch.float32),
|
||||
alpha=one_minus_beta,
|
||||
)
|
||||
return beta
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Applying the EMA to the live model
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def copy_to(self, model: nn.Module) -> None:
|
||||
"""Overwrite the live model's parameters with the EMA shadow.
|
||||
|
||||
In-place and **irreversible** — the previous live weights are
|
||||
lost. Use this only at the very end of training when you want
|
||||
the EMA to *be* the final saved policy. For temporary swaps
|
||||
(e.g. during eval), use :meth:`apply_to`.
|
||||
"""
|
||||
for name, p in model.named_parameters():
|
||||
shadow = self.shadow.get(name)
|
||||
if shadow is not None:
|
||||
p.data.copy_(shadow.to(p.device, dtype=p.dtype))
|
||||
|
||||
@contextmanager
|
||||
def apply_to(self, model: nn.Module) -> Iterator[None]:
|
||||
"""Temporarily swap the live model's weights with the EMA copy.
|
||||
|
||||
On exit, the original live weights are restored byte-for-byte
|
||||
(we keep a backup clone of every tracked parameter inside the
|
||||
context). Use this around eval / sample-logging without
|
||||
disturbing the live training state::
|
||||
|
||||
with ema.apply_to(policy):
|
||||
eval_metrics = evaluate(policy)
|
||||
# policy is back to its pre-eval state here.
|
||||
"""
|
||||
backup: dict[str, torch.Tensor] = {}
|
||||
for name, p in model.named_parameters():
|
||||
if name in self.shadow:
|
||||
backup[name] = p.detach().clone()
|
||||
p.data.copy_(self.shadow[name].to(p.device, dtype=p.dtype))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for name, p in model.named_parameters():
|
||||
if name in backup:
|
||||
p.data.copy_(backup[name].to(p.device, dtype=p.dtype))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Checkpointing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"shadow": {k: v.detach().cpu() for k, v in self.shadow.items()},
|
||||
"num_updates": self.num_updates,
|
||||
"decay": self.decay,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||
self.decay = float(state["decay"])
|
||||
self.warmup_steps = int(state["warmup_steps"])
|
||||
self.num_updates = int(state["num_updates"])
|
||||
new_shadow: dict[str, torch.Tensor] = {}
|
||||
for k, v in state["shadow"].items():
|
||||
t = v.detach()
|
||||
if self.device is not None:
|
||||
t = t.to(self.device)
|
||||
new_shadow[k] = t.float()
|
||||
self.shadow = new_shadow
|
||||
|
||||
def save(self, path: Path | str) -> None:
|
||||
"""Save the shadow as safetensors + a tiny JSON sidecar with metadata.
|
||||
|
||||
Sidecar lives at ``<path>.json`` and stores ``num_updates``,
|
||||
``decay``, ``warmup_steps`` — enough to resume exact EMA state.
|
||||
"""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(
|
||||
{k: v.detach().cpu().contiguous() for k, v in self.shadow.items()},
|
||||
str(path),
|
||||
)
|
||||
meta = {
|
||||
"num_updates": self.num_updates,
|
||||
"decay": self.decay,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
}
|
||||
path.with_suffix(path.suffix + ".json").write_text(json.dumps(meta, indent=2))
|
||||
|
||||
@classmethod
|
||||
def load_from_file(
|
||||
cls,
|
||||
model: nn.Module,
|
||||
path: Path | str,
|
||||
*,
|
||||
device: torch.device | str | None = None,
|
||||
) -> "ModelEMA":
|
||||
"""Reconstruct a ``ModelEMA`` from a previously-saved safetensors + sidecar pair."""
|
||||
path = Path(path)
|
||||
meta_path = path.with_suffix(path.suffix + ".json")
|
||||
meta = json.loads(meta_path.read_text()) if meta_path.exists() else {}
|
||||
ema = cls(
|
||||
model,
|
||||
decay=float(meta.get("decay", 0.999)),
|
||||
warmup_steps=int(meta.get("warmup_steps", 0)),
|
||||
device=device,
|
||||
)
|
||||
shadow = load_file(str(path))
|
||||
target_device = ema.device
|
||||
ema.shadow = {
|
||||
k: (v.to(target_device) if target_device is not None else v).float()
|
||||
for k, v in shadow.items()
|
||||
}
|
||||
ema.num_updates = int(meta.get("num_updates", 0))
|
||||
return ema
|
||||
Reference in New Issue
Block a user