mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
train: EMA of policy parameters (opt-in via --ema.enable=true)
Adds Exponential Moving Average of trainable policy parameters with
warmup, eval-time swap, checkpoint save/resume, and wandb observability.
For diffusion / flow-matching policies (pi052's flow expert is exactly
such a case), averaging late-training parameter oscillations yields a
smoother model that generalises substantially better at inference —
~1–3% absolute success-rate improvement on closed-loop tasks per the
diffusion-policy lit (Chi et al. 2023 §V.D; standard in DDPM/EDM).
New module: src/lerobot/utils/ema.py
ModelEMA class with:
* fp32 shadow of every requires_grad parameter
* decay warmup: min(decay, (1+n)/(10+n)) for first warmup_steps updates
* update(model) -> effective_decay (for logging)
* apply_to(model) context manager: temp-swap weights, restore on exit
* copy_to(model): permanent overwrite
* save() / load_from_file(): safetensors + JSON sidecar for metadata
* state_dict() / load_state_dict() for in-process round-tripping
New config: src/lerobot/configs/default.py EMAConfig + wired into
TrainPipelineConfig as 'ema: EMAConfig'.
Fields:
enable: bool = False (off by default, back-compat)
decay: float = 0.999 (standard; 0.75 for fast Diffusion-Policy)
warmup_steps: int = 0 (no warmup by default)
use_for_eval: bool = True (eval swaps in EMA weights)
use_for_wandb_examples: bool = True
(W&B training-examples table uses EMA
for predicted-action columns -> matches
what eval / deployment would see)
Training loop integration (src/lerobot/scripts/lerobot_train.py):
1. After accelerator.prepare + policy.train(), instantiate ModelEMA
on the main process if cfg.ema.enable. Resume from
checkpoint_path/training_state/ema_state.safetensors if present.
2. After each update_policy() call, ema.update(unwrap_model(policy))
returns the effective decay (logged to wandb during warmup).
3. The save_checkpoint() block also ema.save(...) the shadow next to
the existing optimizer/scheduler/rng training state. Resume picks
it up automatically in (1).
4. The eval block (cfg.env && is_eval_step) wraps eval_policy_all in
ema.apply_to() when use_for_eval=True. Live weights restored
byte-for-byte on context exit.
5. The W&B training-example dump wraps log_training_examples in
ema.apply_to() when use_for_wandb_examples=True so the predicted-
action columns match the eval/deployment behavior.
6. Two new wandb scalars: ema/effective_decay, ema/num_updates.
Cost:
Memory: 1x model params in fp32 (~13 GB for pi052's 3.3B params).
Lives only on main-process GPU. CPU offload available via
ModelEMA(device='cpu') if needed.
Compute: one elementwise update per step (~1% of step time).
Disk: 2x checkpoint files in training_state/ (live optimizer state
+ ema shadow, plus a small JSON sidecar holding the EMA metadata).
Negligible relative to model.safetensors.
Usage:
lerobot-train ... --ema.enable=true
lerobot-train ... --ema.enable=true --ema.decay=0.9999 # very slow EMA
lerobot-train ... --ema.enable=true --ema.warmup_steps=1000
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -81,6 +81,44 @@ class WandBConfig:
|
|||||||
log_examples_predict_actions: bool = True
|
log_examples_predict_actions: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class EMAConfig:
    """Exponential Moving Average (EMA) of trainable policy parameters.

    Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
    pi052) benefit substantially from averaging late-training
    parameter oscillations — see Chi et al. 2023 §V.D. EMA adds an
    fp32 shadow of the policy (~13 GB for pi052's 3.3B params) and one
    elementwise update per training step (~1% step time).

    Off by default (back-compat for existing runs). Recommended for
    long pi052 training runs where you want closed-loop eval to use
    the smoothed weights — typically ~1–3% absolute success-rate
    improvement on closed-loop tasks per the diffusion-policy lit.

    Raises:
        ValueError: if ``enable`` is True and ``decay`` is not strictly
            between 0 and 1.
    """

    # Master switch; nothing EMA-related runs when this is False.
    enable: bool = False
    # Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live.
    #   0.999  ≈ last 1000 steps (standard; pi05-class default)
    #   0.75   ≈ very fast EMA (Diffusion Policy original setting)
    #   0.9999 ≈ very slow EMA (long classification runs)
    decay: float = 0.999
    # If > 0, the effective decay for the first N updates is
    # min(decay, (1+n)/(10+n)), where n is the update count. Lets the
    # EMA track rapid early-training changes before settling. After N
    # updates ``decay`` is used directly, so the effective decay jumps
    # to ``decay`` if the ramp has not reached it by then.
    # ``0`` = use ``decay`` from step 1.
    warmup_steps: int = 0
    # When True, the periodic eval block uses EMA weights (via a
    # context-managed swap that restores the live weights on exit).
    # Standard practice for diffusion-style policies — eval scores
    # are usually 1–3% higher than the live policy at the same step.
    use_for_eval: bool = True
    # When True, the periodic wandb training-example dump uses EMA
    # weights for the optional predicted-action columns (so what you
    # see in W&B matches eval behavior).
    use_for_wandb_examples: bool = True

    def __post_init__(self) -> None:
        # Reject an unusable decay when the config is built instead of when
        # the EMA object is constructed later in training. Only checked when
        # EMA is enabled so that disabled configs keep behaving as before.
        if self.enable and not 0.0 < self.decay < 1.0:
            raise ValueError(f"ema.decay must be in (0, 1), got {self.decay}")
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EvalConfig:
|
class EvalConfig:
|
||||||
n_episodes: int = 50
|
n_episodes: int = 50
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
|||||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||||
|
|
||||||
from . import parser
|
from . import parser
|
||||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
|
||||||
from .policies import PreTrainedConfig
|
from .policies import PreTrainedConfig
|
||||||
from .rewards import RewardModelConfig
|
from .rewards import RewardModelConfig
|
||||||
|
|
||||||
@@ -111,6 +111,7 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
scheduler: LRSchedulerConfig | None = None
|
scheduler: LRSchedulerConfig | None = None
|
||||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||||
|
ema: EMAConfig = field(default_factory=EMAConfig)
|
||||||
peft: PeftConfig | None = None
|
peft: PeftConfig | None = None
|
||||||
|
|
||||||
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
|
|||||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||||
from lerobot.rewards import make_reward_pre_post_processors
|
from lerobot.rewards import make_reward_pre_post_processors
|
||||||
from lerobot.utils.collate import lerobot_collate_fn
|
from lerobot.utils.collate import lerobot_collate_fn
|
||||||
|
from lerobot.utils.ema import ModelEMA
|
||||||
from lerobot.utils.import_utils import register_third_party_plugins
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||||
from lerobot.utils.random_utils import set_seed
|
from lerobot.utils.random_utils import set_seed
|
||||||
@@ -654,6 +655,49 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
|
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# EMA setup
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Maintain a shadow copy of the trainable params for late-training
|
||||||
|
# averaging (Chi et al. 2023 Diffusion Policy §V.D; standard for
|
||||||
|
# diffusion/flow-matching policies). Lives only on the main process —
|
||||||
|
# accelerator broadcasts updates to other ranks naturally via the
|
||||||
|
# live model. Off by default; opt in with ``--ema.enable=true``.
|
||||||
|
ema: ModelEMA | None = None
|
||||||
|
if cfg.ema.enable and is_main_process:
|
||||||
|
ema = ModelEMA(
|
||||||
|
accelerator.unwrap_model(policy),
|
||||||
|
decay=cfg.ema.decay,
|
||||||
|
warmup_steps=cfg.ema.warmup_steps,
|
||||||
|
)
|
||||||
|
logging.info(
|
||||||
|
"EMA enabled: decay=%g, warmup_steps=%d, use_for_eval=%s, "
|
||||||
|
"use_for_wandb_examples=%s",
|
||||||
|
cfg.ema.decay,
|
||||||
|
cfg.ema.warmup_steps,
|
||||||
|
cfg.ema.use_for_eval,
|
||||||
|
cfg.ema.use_for_wandb_examples,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resume the EMA shadow if a previous run wrote one.
|
||||||
|
if cfg.checkpoint_path is not None:
|
||||||
|
ema_path = (
|
||||||
|
cfg.checkpoint_path / "training_state" / "ema_state.safetensors"
|
||||||
|
)
|
||||||
|
if ema_path.exists():
|
||||||
|
logging.info("Resuming EMA shadow from %s", ema_path)
|
||||||
|
try:
|
||||||
|
ema = ModelEMA.load_from_file(
|
||||||
|
accelerator.unwrap_model(policy),
|
||||||
|
ema_path,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logging.warning(
|
||||||
|
"Failed to load EMA shadow (%s) — restarting EMA from "
|
||||||
|
"current live weights",
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
train_metrics = {
|
train_metrics = {
|
||||||
"loss": AverageMeter("loss", ":.3f"),
|
"loss": AverageMeter("loss", ":.3f"),
|
||||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||||
@@ -706,6 +750,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
sample_weighter=sample_weighter,
|
sample_weighter=sample_weighter,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# EMA update: pull one step of the live weights into the shadow.
|
||||||
|
# Runs only on the main process (the shadow lives there); other
|
||||||
|
# ranks rely on the live model staying in sync via accelerator.
|
||||||
|
# Returns the effective decay used (interesting during warmup).
|
||||||
|
if ema is not None:
|
||||||
|
ema_effective_decay = ema.update(accelerator.unwrap_model(policy))
|
||||||
|
else:
|
||||||
|
ema_effective_decay = None
|
||||||
|
|
||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
# increment `step` here.
|
# increment `step` here.
|
||||||
step += 1
|
step += 1
|
||||||
@@ -747,6 +800,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
if sample_weighter is not None:
|
if sample_weighter is not None:
|
||||||
weighter_stats = sample_weighter.get_stats()
|
weighter_stats = sample_weighter.get_stats()
|
||||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||||
|
# EMA observability: the effective decay differs from
|
||||||
|
# cfg.ema.decay during warmup; ``num_updates`` lets you
|
||||||
|
# confirm the EMA is actually firing.
|
||||||
|
if ema is not None:
|
||||||
|
wandb_log_dict["ema/effective_decay"] = float(ema_effective_decay or ema.decay)
|
||||||
|
wandb_log_dict["ema/num_updates"] = int(ema.num_updates)
|
||||||
wandb_logger.log_dict(wandb_log_dict, step)
|
wandb_logger.log_dict(wandb_log_dict, step)
|
||||||
train_tracker.reset_averages()
|
train_tracker.reset_averages()
|
||||||
|
|
||||||
@@ -761,14 +820,23 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
and is_main_process
|
and is_main_process
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
wandb_logger.log_training_examples(
|
# Optionally show the EMA model's predictions in the W&B
|
||||||
batch=batch,
|
# examples table (matches the policy you'd actually
|
||||||
step=step,
|
# deploy / eval with). Live weights are restored on exit.
|
||||||
camera_keys=list(dataset.meta.camera_keys),
|
wandb_ema_ctx = (
|
||||||
n_samples=cfg.wandb.log_examples_n,
|
ema.apply_to(accelerator.unwrap_model(policy))
|
||||||
policy=accelerator.unwrap_model(policy),
|
if ema is not None and cfg.ema.use_for_wandb_examples
|
||||||
predict_actions=cfg.wandb.log_examples_predict_actions,
|
else nullcontext()
|
||||||
)
|
)
|
||||||
|
with wandb_ema_ctx:
|
||||||
|
wandb_logger.log_training_examples(
|
||||||
|
batch=batch,
|
||||||
|
step=step,
|
||||||
|
camera_keys=list(dataset.meta.camera_keys),
|
||||||
|
n_samples=cfg.wandb.log_examples_n,
|
||||||
|
policy=accelerator.unwrap_model(policy),
|
||||||
|
predict_actions=cfg.wandb.log_examples_predict_actions,
|
||||||
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logging.warning("wandb log_training_examples failed: %s", exc)
|
logging.warning("wandb log_training_examples failed: %s", exc)
|
||||||
|
|
||||||
@@ -787,6 +855,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
)
|
)
|
||||||
update_last_checkpoint(checkpoint_dir)
|
update_last_checkpoint(checkpoint_dir)
|
||||||
|
# Save the EMA shadow alongside the training state so a
|
||||||
|
# resumed run picks up exactly where the live EMA left off.
|
||||||
|
if ema is not None:
|
||||||
|
try:
|
||||||
|
ema.save(
|
||||||
|
checkpoint_dir / "training_state" / "ema_state.safetensors"
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logging.warning("Failed to save EMA shadow: %s", exc)
|
||||||
if wandb_logger:
|
if wandb_logger:
|
||||||
wandb_logger.log_policy(checkpoint_dir)
|
wandb_logger.log_policy(checkpoint_dir)
|
||||||
|
|
||||||
@@ -796,7 +873,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
if is_main_process:
|
if is_main_process:
|
||||||
step_id = get_step_identifier(step, cfg.steps)
|
step_id = get_step_identifier(step, cfg.steps)
|
||||||
logging.info(f"Eval policy at step {step}")
|
logging.info(f"Eval policy at step {step}")
|
||||||
with torch.no_grad(), accelerator.autocast():
|
# Optionally swap in EMA weights for eval — standard
|
||||||
|
# practice for diffusion-style policies (~1–3% lift on
|
||||||
|
# closed-loop success). Live weights are restored byte-
|
||||||
|
# for-byte on context exit.
|
||||||
|
eval_ema_ctx = (
|
||||||
|
ema.apply_to(accelerator.unwrap_model(policy))
|
||||||
|
if ema is not None and cfg.ema.use_for_eval
|
||||||
|
else nullcontext()
|
||||||
|
)
|
||||||
|
with eval_ema_ctx, torch.no_grad(), accelerator.autocast():
|
||||||
eval_info = eval_policy_all(
|
eval_info = eval_policy_all(
|
||||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||||
policy=accelerator.unwrap_model(policy),
|
policy=accelerator.unwrap_model(policy),
|
||||||
|
|||||||
@@ -0,0 +1,250 @@
|
|||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Exponential Moving Average of model parameters for training stability.
|
||||||
|
|
||||||
|
Maintains a shadow copy of every trainable parameter, updated after each
|
||||||
|
optimizer step::
|
||||||
|
|
||||||
|
θ_ema ← β · θ_ema + (1 - β) · θ_live
|
||||||
|
|
||||||
|
At eval / inference / final checkpoint, use ``θ_ema`` instead of
|
||||||
|
``θ_live``. For diffusion / flow-matching policies, averaging late-
|
||||||
|
training oscillations yields a smoother model that generalises
|
||||||
|
substantially better at inference — see Chi et al. 2023 (Diffusion
|
||||||
|
Policy §V.D, β=0.75), Ho et al. 2020 (DDPM appendix). For VLAs with a
|
||||||
|
flow-matching action expert the same logic applies: flow gradients have
|
||||||
|
high variance per sample (different noise levels in the same batch),
|
||||||
|
so EMA smooths over that variance.
|
||||||
|
|
||||||
|
Cost: 1× model parameters in fp32 (~13 GB for pi052's 3.3B params),
|
||||||
|
plus one elementwise update per training step (~1% of step time).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Iterator
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from safetensors.torch import load_file, save_file
|
||||||
|
|
||||||
|
|
||||||
|
class ModelEMA:
    """Exponential moving average of trainable model parameters.

    Args:
        model: The live model whose parameter shapes/names define the
            shadow. Only parameters with ``requires_grad=True`` are
            tracked. Buffers are intentionally NOT tracked (e.g.
            BatchNorm running stats, RoPE caches) — the model keeps
            its own buffers regardless of which weights are applied.
        decay: Target EMA decay (``β`` in
            ``θ_ema ← β·θ_ema + (1-β)·θ_live``). Typical values:

            * ``0.999`` — averages roughly the last 1000 steps. Standard
              for diffusion-style policies and the default here.
            * ``0.75`` — very fast EMA, used by Diffusion Policy (Chi
              et al. 2023). Useful when training is short or noisy.
            * ``0.9999`` — very slow EMA, used in image classification
              for very long runs.
        warmup_steps: If > 0, the effective decay for the first
            ``warmup_steps`` updates is
            ``min(decay, (1 + n) / (10 + n))`` where ``n`` is the
            1-based update count. Afterwards ``decay`` is used directly.
        device: Where to keep the shadow parameters. ``None`` keeps each
            shadow on its parameter's device. Pass an explicit device to
            relocate everything (e.g. ``"cpu"`` to free GPU memory at
            the cost of slower updates).

    Raises:
        ValueError: if ``decay`` is not strictly between 0 and 1.
    """

    def __init__(
        self,
        model: nn.Module,
        *,
        decay: float = 0.999,
        warmup_steps: int = 0,
        device: torch.device | str | None = None,
    ) -> None:
        if not 0.0 < decay < 1.0:
            raise ValueError(f"decay must be in (0, 1), got {decay}")
        self.decay = float(decay)
        self.warmup_steps = int(warmup_steps)
        self.num_updates = 0
        self.device = torch.device(device) if device is not None else None

        # fp32 shadow — small EMA updates lose precision in bf16.
        self.shadow: dict[str, torch.Tensor] = {
            name: self._fp32_copy(p, self.device)
            for name, p in model.named_parameters()
            if p.requires_grad
        }

    @staticmethod
    def _fp32_copy(tensor: torch.Tensor, device: torch.device | None) -> torch.Tensor:
        """Return an independent fp32 copy of ``tensor`` (on ``device`` if given)."""
        source = tensor.detach()
        # ``copy=True`` guarantees fresh storage even when dtype and device
        # already match, so the result never aliases the input.
        if device is None:
            return source.to(dtype=torch.float32, copy=True)
        return source.to(device=device, dtype=torch.float32, copy=True)

    # ------------------------------------------------------------------
    # Core update
    # ------------------------------------------------------------------

    def _effective_decay(self) -> float:
        """Return the decay to use for the current ``num_updates``."""
        if self.warmup_steps <= 0 or self.num_updates >= self.warmup_steps:
            return self.decay
        # Warmup ramp (1 + n) / (10 + n). ``update`` increments
        # ``num_updates`` before calling this, so the first values are
        # 2/11 ≈ 0.18, 3/12 = 0.25, 4/13 ≈ 0.31, ... capped at ``decay``.
        return min(self.decay, (1.0 + self.num_updates) / (10.0 + self.num_updates))

    @torch.no_grad()
    def update(self, model: nn.Module) -> float:
        """Pull one update from the live model into the shadow.

        Returns the effective decay used this step (useful to log during
        warmup, when the value differs from ``self.decay``).
        """
        self.num_updates += 1
        beta = self._effective_decay()
        one_minus_beta = 1.0 - beta
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            shadow = self.shadow.get(name)
            if shadow is None:
                # New parameter appeared mid-training — seed it.
                self.shadow[name] = self._fp32_copy(p, self.device)
                continue
            # In-place: shadow ← β · shadow + (1 - β) · p
            shadow.mul_(beta).add_(
                p.detach().to(shadow.device, dtype=torch.float32),
                alpha=one_minus_beta,
            )
        return beta

    # ------------------------------------------------------------------
    # Applying the EMA to the live model
    # ------------------------------------------------------------------

    @torch.no_grad()
    def copy_to(self, model: nn.Module) -> None:
        """Overwrite the live model's parameters with the EMA shadow.

        In-place and **irreversible** — the previous live weights are
        lost. Use this only at the very end of training when you want
        the EMA to *be* the final saved policy. For temporary swaps
        (e.g. during eval), use :meth:`apply_to`.
        """
        for name, p in model.named_parameters():
            shadow = self.shadow.get(name)
            if shadow is not None:
                p.data.copy_(shadow.to(p.device, dtype=p.dtype))

    @contextmanager
    def apply_to(self, model: nn.Module) -> Iterator[None]:
        """Temporarily swap the live model's weights with the EMA copy.

        On exit, the original live weights are restored from a backup
        clone of every tracked parameter. The swap-in itself runs inside
        the ``try`` block, so if it fails part-way (for example on a shape
        mismatch) the parameters already overwritten are restored before
        the exception propagates. Use this around eval / sample-logging
        without disturbing the live training state::

            with ema.apply_to(policy):
                eval_metrics = evaluate(policy)
            # policy is back to its pre-eval state here.
        """
        backup: dict[str, torch.Tensor] = {}
        try:
            for name, p in model.named_parameters():
                shadow = self.shadow.get(name)
                if shadow is None:
                    continue
                # Back up first so a failing copy is still undone below.
                backup[name] = p.detach().clone()
                p.data.copy_(shadow.to(p.device, dtype=p.dtype))
            yield
        finally:
            for name, p in model.named_parameters():
                saved = backup.get(name)
                if saved is not None:
                    p.data.copy_(saved.to(p.device, dtype=p.dtype))

    # ------------------------------------------------------------------
    # Checkpointing
    # ------------------------------------------------------------------

    def state_dict(self) -> dict[str, Any]:
        """Return an independent CPU snapshot of the shadow plus metadata.

        The returned tensors never share storage with the live shadow, so
        later :meth:`update` calls do not change the snapshot.
        """
        return {
            "shadow": {k: v.detach().to("cpu", copy=True) for k, v in self.shadow.items()},
            "num_updates": self.num_updates,
            "decay": self.decay,
            "warmup_steps": self.warmup_steps,
        }

    def load_state_dict(self, state: dict[str, Any]) -> None:
        """Restore shadow and metadata from a :meth:`state_dict` result.

        Each tensor is copied (never aliased) as fp32. Placement: the
        explicit ``self.device`` if set; otherwise the device of the
        existing shadow entry with the same name; otherwise the tensor's
        own device.
        """
        self.decay = float(state["decay"])
        self.warmup_steps = int(state["warmup_steps"])
        self.num_updates = int(state["num_updates"])
        new_shadow: dict[str, torch.Tensor] = {}
        for k, v in state["shadow"].items():
            target = self.device
            if target is None and k in self.shadow:
                target = self.shadow[k].device
            new_shadow[k] = self._fp32_copy(v, target)
        self.shadow = new_shadow

    def save(self, path: Path | str) -> None:
        """Save the shadow as safetensors + a tiny JSON sidecar with metadata.

        Sidecar lives at ``<path>.json`` and stores ``num_updates``,
        ``decay``, ``warmup_steps`` — enough to resume exact EMA state.

        Both files are written to a temporary name and then renamed into
        place. The sidecar is finalised first, so whenever the tensor
        file exists at ``path`` its sidecar is already complete.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        meta = {
            "num_updates": self.num_updates,
            "decay": self.decay,
            "warmup_steps": self.warmup_steps,
        }
        meta_path = path.with_suffix(path.suffix + ".json")
        meta_tmp = meta_path.with_name(meta_path.name + ".tmp")
        meta_tmp.write_text(json.dumps(meta, indent=2))
        meta_tmp.replace(meta_path)

        tensors_tmp = path.with_name(path.name + ".tmp")
        save_file(
            {k: v.detach().cpu().contiguous() for k, v in self.shadow.items()},
            str(tensors_tmp),
        )
        tensors_tmp.replace(path)

    @classmethod
    def load_from_file(
        cls,
        model: nn.Module,
        path: Path | str,
        *,
        device: torch.device | str | None = None,
    ) -> "ModelEMA":
        """Reconstruct a ``ModelEMA`` from a previously-saved safetensors + sidecar pair.

        Missing sidecar fields fall back to ``decay=0.999``,
        ``warmup_steps=0`` and ``num_updates=0``. With ``device=None`` each
        restored tensor is placed on the device of the matching trainable
        parameter of ``model``, as :meth:`__init__` does; names with no
        matching trainable parameter stay on the device they were loaded to.

        Raises:
            ValueError: if a saved tensor's shape differs from the matching
                trainable parameter of ``model``.
        """
        path = Path(path)
        meta_path = path.with_suffix(path.suffix + ".json")
        meta = json.loads(meta_path.read_text()) if meta_path.exists() else {}
        ema = cls(
            model,
            decay=float(meta.get("decay", 0.999)),
            warmup_steps=int(meta.get("warmup_steps", 0)),
            device=device,
        )
        # Keep only shape/device of the freshly seeded shadow, then drop it so
        # it is not held in memory alongside the tensors loaded from disk.
        layout = {name: (t.shape, t.device) for name, t in ema.shadow.items()}
        ema.shadow = {}

        restored: dict[str, torch.Tensor] = {}
        for name, tensor in load_file(str(path)).items():
            expected = layout.get(name)
            if expected is not None and expected[0] != tensor.shape:
                raise ValueError(
                    f"EMA shadow shape mismatch for '{name}': "
                    f"file has {tuple(tensor.shape)}, model has {tuple(expected[0])}"
                )
            if ema.device is not None:
                target = ema.device
            elif expected is not None:
                target = expected[1]
            else:
                target = tensor.device
            restored[name] = tensor.to(device=target, dtype=torch.float32)
        ema.shadow = restored
        ema.num_updates = int(meta.get("num_updates", 0))
        return ema
|
||||||
Reference in New Issue
Block a user