mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 22:49:48 +00:00
train: EMA of policy parameters (opt-in via --ema.enable=true)
Adds Exponential Moving Average of trainable policy parameters with
warmup, eval-time swap, checkpoint save/resume, and wandb observability.
For diffusion / flow-matching policies (pi052's flow expert exactly
qualifies), averaging late-training parameter oscillations yields a
smoother model that generalises substantially better at inference —
~1–3% absolute success-rate improvement on closed-loop tasks per the
diffusion-policy lit (Chi et al. 2023 §V.D; standard in DDPM/EDM).
New module: src/lerobot/utils/ema.py
ModelEMA class with:
* fp32 shadow of every requires_grad parameter
* decay warmup: min(decay, (1+n)/(10+n)) for first warmup_steps updates
* update(model) -> effective_decay (for logging)
* apply_to(model) context manager: temp-swap weights, restore on exit
* copy_to(model): permanent overwrite
* save() / load_from_file(): safetensors + JSON sidecar for metadata
* state_dict() / load_state_dict() for in-process round-tripping
New config: src/lerobot/configs/default.py EMAConfig + wired into
TrainPipelineConfig as 'ema: EMAConfig'.
Fields:
enable: bool = False (off by default, back-compat)
decay: float = 0.999 (standard; 0.75 for fast Diffusion-Policy)
warmup_steps: int = 0 (no warmup by default)
use_for_eval: bool = True (eval swaps in EMA weights)
use_for_wandb_examples: bool = True
(W&B training-examples table uses EMA
for predicted-action columns -> matches
what eval / deployment would see)
Training loop integration (src/lerobot/scripts/lerobot_train.py):
1. After accelerator.prepare + policy.train(), instantiate ModelEMA
on the main process if cfg.ema.enable. Resume from
checkpoint_path/training_state/ema_state.safetensors if present.
2. After each update_policy() call, ema.update(unwrap_model(policy))
returns the effective decay (logged to wandb during warmup).
3. The save_checkpoint() block also ema.save(...) the shadow next to
the existing optimizer/scheduler/rng training state. Resume picks
it up automatically in (1).
4. The eval block (cfg.env && is_eval_step) wraps eval_policy_all in
ema.apply_to() when use_for_eval=True. Live weights restored
byte-for-byte on context exit.
5. The W&B training-example dump wraps log_training_examples in
ema.apply_to() when use_for_wandb_examples=True so the predicted-
action columns match the eval/deployment behavior.
6. Two new wandb scalars: ema/effective_decay, ema/num_updates.
Cost:
Memory: 1x model params in fp32 (~13 GB for pi052's 3.3B params).
Lives only on main-process GPU. CPU offload available via
ModelEMA(device='cpu') if needed.
Compute: one elementwise update per step (~1% of step time).
Eval: 2x checkpoint files in training_state/ (live optimizer state
+ ema shadow). Negligible relative to model.safetensors.
Usage:
lerobot-train ... --ema.enable=true
lerobot-train ... --ema.enable=true --ema.decay=0.9999 # very slow EMA
lerobot-train ... --ema.enable=true --ema.warmup_steps=1000
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -81,6 +81,44 @@ class WandBConfig:
|
||||
log_examples_predict_actions: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EMAConfig:
|
||||
"""Exponential Moving Average of trainable policy parameters.
|
||||
|
||||
Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5,
|
||||
pi052) benefit substantially from averaging late-training
|
||||
parameter oscillations — see Chi et al. 2023 §V.D. EMA adds a
|
||||
fp32 shadow of the policy (~13 GB for pi052's 3.3B params) and one
|
||||
elementwise update per training step (~1% step time).
|
||||
|
||||
Off by default (back-compat for existing runs). Recommended for
|
||||
long pi052 training runs where you want closed-loop eval to use
|
||||
the smoothed weights — typically ~1–3% absolute success-rate
|
||||
improvement on closed-loop tasks per the diffusion-policy lit.
|
||||
"""
|
||||
|
||||
enable: bool = False
|
||||
# Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live.
|
||||
# 0.999 ≈ last 1000 steps (standard; pi05-class default)
|
||||
# 0.75 ≈ very fast EMA (Diffusion Policy original setting)
|
||||
# 0.9999 ≈ very slow EMA (long classification runs)
|
||||
decay: float = 0.999
|
||||
# If > 0, ramp effective decay up to ``decay`` over the first N
|
||||
# updates as min(decay, (1+n)/(10+n)). Lets the EMA track rapid
|
||||
# early-training changes before settling. ``0`` = use ``decay``
|
||||
# from step 1.
|
||||
warmup_steps: int = 0
|
||||
# When True, the periodic eval block uses EMA weights (via a
|
||||
# context-managed swap that restores the live weights on exit).
|
||||
# Standard practice for diffusion-style policies — eval scores
|
||||
# are usually 1–3% higher than the live policy at the same step.
|
||||
use_for_eval: bool = True
|
||||
# When True, the periodic wandb training-example dump uses EMA
|
||||
# weights for the optional predicted-action columns (so what you
|
||||
# see in W&B matches eval behavior).
|
||||
use_for_wandb_examples: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalConfig:
|
||||
n_episodes: int = 50
|
||||
|
||||
@@ -30,7 +30,7 @@ from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.sample_weighting import SampleWeightingConfig
|
||||
|
||||
from . import parser
|
||||
from .default import DatasetConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .default import DatasetConfig, EMAConfig, EvalConfig, PeftConfig, WandBConfig
|
||||
from .policies import PreTrainedConfig
|
||||
from .rewards import RewardModelConfig
|
||||
|
||||
@@ -111,6 +111,7 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
ema: EMAConfig = field(default_factory=EMAConfig)
|
||||
peft: PeftConfig | None = None
|
||||
|
||||
# VQA oversampling. When set (a fraction in (0, 1)), the training
|
||||
|
||||
@@ -50,6 +50,7 @@ from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||
from lerobot.rewards import make_reward_pre_post_processors
|
||||
from lerobot.utils.collate import lerobot_collate_fn
|
||||
from lerobot.utils.ema import ModelEMA
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
@@ -654,6 +655,49 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
|
||||
policy.train()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# EMA setup
|
||||
# ------------------------------------------------------------------
|
||||
# Maintain a shadow copy of the trainable params for late-training
|
||||
# averaging (Chi et al. 2023 Diffusion Policy §V.D; standard for
|
||||
# diffusion/flow-matching policies). Lives only on the main process —
|
||||
# accelerator broadcasts updates to other ranks naturally via the
|
||||
# live model. Off by default; opt in with ``--ema.enable=true``.
|
||||
ema: ModelEMA | None = None
|
||||
if cfg.ema.enable and is_main_process:
|
||||
ema = ModelEMA(
|
||||
accelerator.unwrap_model(policy),
|
||||
decay=cfg.ema.decay,
|
||||
warmup_steps=cfg.ema.warmup_steps,
|
||||
)
|
||||
logging.info(
|
||||
"EMA enabled: decay=%g, warmup_steps=%d, use_for_eval=%s, "
|
||||
"use_for_wandb_examples=%s",
|
||||
cfg.ema.decay,
|
||||
cfg.ema.warmup_steps,
|
||||
cfg.ema.use_for_eval,
|
||||
cfg.ema.use_for_wandb_examples,
|
||||
)
|
||||
|
||||
# Resume the EMA shadow if a previous run wrote one.
|
||||
if cfg.checkpoint_path is not None:
|
||||
ema_path = (
|
||||
cfg.checkpoint_path / "training_state" / "ema_state.safetensors"
|
||||
)
|
||||
if ema_path.exists():
|
||||
logging.info("Resuming EMA shadow from %s", ema_path)
|
||||
try:
|
||||
ema = ModelEMA.load_from_file(
|
||||
accelerator.unwrap_model(policy),
|
||||
ema_path,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning(
|
||||
"Failed to load EMA shadow (%s) — restarting EMA from "
|
||||
"current live weights",
|
||||
exc,
|
||||
)
|
||||
|
||||
train_metrics = {
|
||||
"loss": AverageMeter("loss", ":.3f"),
|
||||
"grad_norm": AverageMeter("grdn", ":.3f"),
|
||||
@@ -706,6 +750,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
sample_weighter=sample_weighter,
|
||||
)
|
||||
|
||||
# EMA update: pull one step of the live weights into the shadow.
|
||||
# Runs only on the main process (the shadow lives there); other
|
||||
# ranks rely on the live model staying in sync via accelerator.
|
||||
# Returns the effective decay used (interesting during warmup).
|
||||
if ema is not None:
|
||||
ema_effective_decay = ema.update(accelerator.unwrap_model(policy))
|
||||
else:
|
||||
ema_effective_decay = None
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
@@ -747,6 +800,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if sample_weighter is not None:
|
||||
weighter_stats = sample_weighter.get_stats()
|
||||
wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()})
|
||||
# EMA observability: the effective decay differs from
|
||||
# cfg.ema.decay during warmup; ``num_updates`` lets you
|
||||
# confirm the EMA is actually firing.
|
||||
if ema is not None:
|
||||
wandb_log_dict["ema/effective_decay"] = float(ema_effective_decay or ema.decay)
|
||||
wandb_log_dict["ema/num_updates"] = int(ema.num_updates)
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
@@ -761,14 +820,23 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
and is_main_process
|
||||
):
|
||||
try:
|
||||
wandb_logger.log_training_examples(
|
||||
batch=batch,
|
||||
step=step,
|
||||
camera_keys=list(dataset.meta.camera_keys),
|
||||
n_samples=cfg.wandb.log_examples_n,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
predict_actions=cfg.wandb.log_examples_predict_actions,
|
||||
# Optionally show the EMA model's predictions in the W&B
|
||||
# examples table (matches the policy you'd actually
|
||||
# deploy / eval with). Live weights are restored on exit.
|
||||
wandb_ema_ctx = (
|
||||
ema.apply_to(accelerator.unwrap_model(policy))
|
||||
if ema is not None and cfg.ema.use_for_wandb_examples
|
||||
else nullcontext()
|
||||
)
|
||||
with wandb_ema_ctx:
|
||||
wandb_logger.log_training_examples(
|
||||
batch=batch,
|
||||
step=step,
|
||||
camera_keys=list(dataset.meta.camera_keys),
|
||||
n_samples=cfg.wandb.log_examples_n,
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
predict_actions=cfg.wandb.log_examples_predict_actions,
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("wandb log_training_examples failed: %s", exc)
|
||||
|
||||
@@ -787,6 +855,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
update_last_checkpoint(checkpoint_dir)
|
||||
# Save the EMA shadow alongside the training state so a
|
||||
# resumed run picks up exactly where the live EMA left off.
|
||||
if ema is not None:
|
||||
try:
|
||||
ema.save(
|
||||
checkpoint_dir / "training_state" / "ema_state.safetensors"
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("Failed to save EMA shadow: %s", exc)
|
||||
if wandb_logger:
|
||||
wandb_logger.log_policy(checkpoint_dir)
|
||||
|
||||
@@ -796,7 +873,16 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if is_main_process:
|
||||
step_id = get_step_identifier(step, cfg.steps)
|
||||
logging.info(f"Eval policy at step {step}")
|
||||
with torch.no_grad(), accelerator.autocast():
|
||||
# Optionally swap in EMA weights for eval — standard
|
||||
# practice for diffusion-style policies (~1–3% lift on
|
||||
# closed-loop success). Live weights are restored byte-
|
||||
# for-byte on context exit.
|
||||
eval_ema_ctx = (
|
||||
ema.apply_to(accelerator.unwrap_model(policy))
|
||||
if ema is not None and cfg.ema.use_for_eval
|
||||
else nullcontext()
|
||||
)
|
||||
with eval_ema_ctx, torch.no_grad(), accelerator.autocast():
|
||||
eval_info = eval_policy_all(
|
||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
|
||||
@@ -0,0 +1,250 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Exponential Moving Average of model parameters for training stability.
|
||||
|
||||
Maintains a shadow copy of every trainable parameter, updated after each
|
||||
optimizer step::
|
||||
|
||||
θ_ema ← β · θ_ema + (1 - β) · θ_live
|
||||
|
||||
At eval / inference / final checkpoint, use ``θ_ema`` instead of
|
||||
``θ_live``. For diffusion / flow-matching policies, averaging late-
|
||||
training oscillations yields a smoother model that generalises
|
||||
substantially better at inference — see Chi et al. 2023 (Diffusion
|
||||
Policy §V.D, β=0.75), Ho et al. 2020 (DDPM appendix). For VLAs with a
|
||||
flow-matching action expert the same logic applies: flow gradients have
|
||||
high variance per sample (different noise levels in the same batch),
|
||||
so EMA smooths over that variance.
|
||||
|
||||
Cost: 1× model parameters in fp32 (~13 GB for pi052's 3.3B params),
|
||||
plus one elementwise update per training step (~1% of step time).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterator
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
class ModelEMA:
|
||||
"""Exponential moving average of trainable model parameters.
|
||||
|
||||
Args:
|
||||
model: The live model whose parameter shapes/names define the
|
||||
shadow. Only parameters with ``requires_grad=True`` are
|
||||
tracked. Buffers are intentionally NOT tracked (LayerNorm
|
||||
running stats, RoPE caches, etc.) — they are updated in
|
||||
``train()`` mode regardless of which weights we apply.
|
||||
decay: Target EMA decay (``β`` in
|
||||
``θ_ema ← β·θ_ema + (1-β)·θ_live``). Typical values:
|
||||
|
||||
* ``0.999`` — averages roughly the last 1000 steps. Standard
|
||||
for diffusion-style policies and the default here.
|
||||
* ``0.75`` — very fast EMA, used by Diffusion Policy (Chi
|
||||
et al. 2023). Useful when training is short or noisy.
|
||||
* ``0.9999`` — very slow EMA, used in image classification
|
||||
for very long runs.
|
||||
warmup_steps: If > 0, ramp the effective decay from a low value
|
||||
up to ``decay`` over the first ``warmup_steps`` updates as
|
||||
``min(decay, (1 + n) / (10 + n))``. Lets the EMA track
|
||||
rapid early-training changes before settling on the target.
|
||||
device: Where to keep the shadow parameters. ``None`` keeps each
|
||||
shadow on its parameter's device (good for FSDP / multi-GPU).
|
||||
Pass an explicit device to relocate everything (e.g. ``"cpu"``
|
||||
to free GPU memory at the cost of slower updates).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
*,
|
||||
decay: float = 0.999,
|
||||
warmup_steps: int = 0,
|
||||
device: torch.device | str | None = None,
|
||||
) -> None:
|
||||
if not 0.0 < decay < 1.0:
|
||||
raise ValueError(f"decay must be in (0, 1), got {decay}")
|
||||
self.decay = float(decay)
|
||||
self.warmup_steps = int(warmup_steps)
|
||||
self.num_updates = 0
|
||||
self.device = torch.device(device) if device is not None else None
|
||||
|
||||
# fp32 shadow — small EMA updates lose precision in bf16.
|
||||
self.shadow: dict[str, torch.Tensor] = {}
|
||||
for name, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
shadow_p = p.detach().clone().float()
|
||||
if self.device is not None:
|
||||
shadow_p = shadow_p.to(self.device)
|
||||
self.shadow[name] = shadow_p
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core update
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _effective_decay(self) -> float:
|
||||
if self.warmup_steps <= 0 or self.num_updates >= self.warmup_steps:
|
||||
return self.decay
|
||||
# Standard EMA warmup (timm / diffusers convention): grows
|
||||
# 0.09, 0.16, 0.23, ... and saturates at ``decay``.
|
||||
return min(self.decay, (1.0 + self.num_updates) / (10.0 + self.num_updates))
|
||||
|
||||
@torch.no_grad()
|
||||
def update(self, model: nn.Module) -> float:
|
||||
"""Pull one update from the live model into the shadow.
|
||||
|
||||
Returns the effective decay used this step (useful to log during
|
||||
warmup, when the value differs from ``self.decay``).
|
||||
"""
|
||||
self.num_updates += 1
|
||||
beta = self._effective_decay()
|
||||
one_minus_beta = 1.0 - beta
|
||||
for name, p in model.named_parameters():
|
||||
if not p.requires_grad:
|
||||
continue
|
||||
shadow = self.shadow.get(name)
|
||||
if shadow is None:
|
||||
# New parameter appeared mid-training — seed it.
|
||||
shadow = p.detach().clone().float()
|
||||
if self.device is not None:
|
||||
shadow = shadow.to(self.device)
|
||||
self.shadow[name] = shadow
|
||||
continue
|
||||
# In-place fused: shadow ← β · shadow + (1 - β) · p
|
||||
shadow.mul_(beta).add_(
|
||||
p.detach().to(shadow.device, dtype=torch.float32),
|
||||
alpha=one_minus_beta,
|
||||
)
|
||||
return beta
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Applying the EMA to the live model
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@torch.no_grad()
|
||||
def copy_to(self, model: nn.Module) -> None:
|
||||
"""Overwrite the live model's parameters with the EMA shadow.
|
||||
|
||||
In-place and **irreversible** — the previous live weights are
|
||||
lost. Use this only at the very end of training when you want
|
||||
the EMA to *be* the final saved policy. For temporary swaps
|
||||
(e.g. during eval), use :meth:`apply_to`.
|
||||
"""
|
||||
for name, p in model.named_parameters():
|
||||
shadow = self.shadow.get(name)
|
||||
if shadow is not None:
|
||||
p.data.copy_(shadow.to(p.device, dtype=p.dtype))
|
||||
|
||||
@contextmanager
|
||||
def apply_to(self, model: nn.Module) -> Iterator[None]:
|
||||
"""Temporarily swap the live model's weights with the EMA copy.
|
||||
|
||||
On exit, the original live weights are restored byte-for-byte
|
||||
(we keep a backup clone of every tracked parameter inside the
|
||||
context). Use this around eval / sample-logging without
|
||||
disturbing the live training state::
|
||||
|
||||
with ema.apply_to(policy):
|
||||
eval_metrics = evaluate(policy)
|
||||
# policy is back to its pre-eval state here.
|
||||
"""
|
||||
backup: dict[str, torch.Tensor] = {}
|
||||
for name, p in model.named_parameters():
|
||||
if name in self.shadow:
|
||||
backup[name] = p.detach().clone()
|
||||
p.data.copy_(self.shadow[name].to(p.device, dtype=p.dtype))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for name, p in model.named_parameters():
|
||||
if name in backup:
|
||||
p.data.copy_(backup[name].to(p.device, dtype=p.dtype))
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Checkpointing
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def state_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"shadow": {k: v.detach().cpu() for k, v in self.shadow.items()},
|
||||
"num_updates": self.num_updates,
|
||||
"decay": self.decay,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
}
|
||||
|
||||
def load_state_dict(self, state: dict[str, Any]) -> None:
|
||||
self.decay = float(state["decay"])
|
||||
self.warmup_steps = int(state["warmup_steps"])
|
||||
self.num_updates = int(state["num_updates"])
|
||||
new_shadow: dict[str, torch.Tensor] = {}
|
||||
for k, v in state["shadow"].items():
|
||||
t = v.detach()
|
||||
if self.device is not None:
|
||||
t = t.to(self.device)
|
||||
new_shadow[k] = t.float()
|
||||
self.shadow = new_shadow
|
||||
|
||||
def save(self, path: Path | str) -> None:
|
||||
"""Save the shadow as safetensors + a tiny JSON sidecar with metadata.
|
||||
|
||||
Sidecar lives at ``<path>.json`` and stores ``num_updates``,
|
||||
``decay``, ``warmup_steps`` — enough to resume exact EMA state.
|
||||
"""
|
||||
path = Path(path)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
save_file(
|
||||
{k: v.detach().cpu().contiguous() for k, v in self.shadow.items()},
|
||||
str(path),
|
||||
)
|
||||
meta = {
|
||||
"num_updates": self.num_updates,
|
||||
"decay": self.decay,
|
||||
"warmup_steps": self.warmup_steps,
|
||||
}
|
||||
path.with_suffix(path.suffix + ".json").write_text(json.dumps(meta, indent=2))
|
||||
|
||||
@classmethod
|
||||
def load_from_file(
|
||||
cls,
|
||||
model: nn.Module,
|
||||
path: Path | str,
|
||||
*,
|
||||
device: torch.device | str | None = None,
|
||||
) -> "ModelEMA":
|
||||
"""Reconstruct a ``ModelEMA`` from a previously-saved safetensors + sidecar pair."""
|
||||
path = Path(path)
|
||||
meta_path = path.with_suffix(path.suffix + ".json")
|
||||
meta = json.loads(meta_path.read_text()) if meta_path.exists() else {}
|
||||
ema = cls(
|
||||
model,
|
||||
decay=float(meta.get("decay", 0.999)),
|
||||
warmup_steps=int(meta.get("warmup_steps", 0)),
|
||||
device=device,
|
||||
)
|
||||
shadow = load_file(str(path))
|
||||
target_device = ema.device
|
||||
ema.shadow = {
|
||||
k: (v.to(target_device) if target_device is not None else v).float()
|
||||
for k, v in shadow.items()
|
||||
}
|
||||
ema.num_updates = int(meta.get("num_updates", 0))
|
||||
return ema
|
||||
Reference in New Issue
Block a user