diff --git a/pyproject.toml b/pyproject.toml index fbae1897e..b14dde5e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,11 @@ dependencies = [ "termcolor>=2.4.0,<4.0.0", "tqdm>=4.66.0,<5.0.0", + # Training utilities + # EMA of policy parameters (Diffusion Policy / pi05 style). Tiny + # pure-python dependency — preferred over a hand-rolled implementation. + "ema-pytorch>=0.7.7,<1.0.0", + # Build tools (required by opencv-python-headless on some platforms) "cmake>=3.29.0.1,<4.2.0", "setuptools>=71.0.0,<81.0.0", diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index 4e8ac0446..a03c436cf 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -87,35 +87,44 @@ class EMAConfig: Diffusion / flow-matching policies (Diffusion Policy, π0/π0.5, pi052) benefit substantially from averaging late-training - parameter oscillations — see Chi et al. 2023 §V.D. EMA adds a - fp32 shadow of the policy (~13 GB for pi052's 3.3B params) and one - elementwise update per training step (~1% step time). + parameter oscillations — see Chi et al. 2023 §V.D. The official + JAX openpi trainer ships EMA with ``ema_decay=0.99`` (default) and + ``0.999`` for its pi05_libero config; the openpi PyTorch port + explicitly lists EMA as unsupported, and LeRobot main inherited + that gap. Enabling this flag plugs ema-pytorch + (https://github.com/lucidrains/ema-pytorch) into the LeRobot + training loop with a shadow ``nn.Module`` clone of the policy. - Off by default (back-compat for existing runs). Recommended for - long pi052 training runs where you want closed-loop eval to use - the smoothed weights — typically ~1–3% absolute success-rate - improvement on closed-loop tasks per the diffusion-policy lit. + Cost: 1× model params in fp32 shadow (~13 GB for pi052's 3.3B + params) + one elementwise update per training step (~1% step time). + + Off by default (back-compat). Recommended for long pi052 training + runs — typically ~1–3% absolute success-rate improvement on + closed-loop tasks per the diffusion-policy literature. """ enable: bool = False - # Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live. - # 0.999 ≈ last 1000 steps (standard; pi05-class default) - # 0.75 ≈ very fast EMA (Diffusion Policy original setting) - # 0.9999 ≈ very slow EMA (long classification runs) + # Target EMA decay β in θ_ema ← β·θ_ema + (1-β)·θ_live (passed to + # ema-pytorch as ``beta``). + # 0.999 — last ~1000 steps; pi05_libero default in openpi + # 0.99 — last ~100 steps; openpi top-level default + # 0.75 — very fast EMA (Diffusion Policy original setting) + # 0.9999 — very slow EMA (long classification runs) decay: float = 0.999 - # If > 0, ramp effective decay up to ``decay`` over the first N - # updates as min(decay, (1+n)/(10+n)). Lets the EMA track rapid - # early-training changes before settling. ``0`` = use ``decay`` - # from step 1. + # Skip the first N calls to ``ema.update()``; during this window + # the shadow is just a hard copy of the live weights (no averaging). + # Lets early-training rapid changes settle before averaging begins. + # Maps to ema-pytorch's ``update_after_step`` (NOT a smooth decay + # ramp like older lerobot EMA implementations). warmup_steps: int = 0 - # When True, the periodic eval block uses EMA weights (via a - # context-managed swap that restores the live weights on exit). - # Standard practice for diffusion-style policies — eval scores - # are usually 1–3% higher than the live policy at the same step. + # When True, the periodic eval block uses the EMA shadow model + # directly (``ema.ema_model``) instead of the live policy. Standard + # practice for diffusion-style policies — eval scores are usually + # 1–3% higher than the live policy at the same step. use_for_eval: bool = True - # When True, the periodic wandb training-example dump uses EMA - # weights for the optional predicted-action columns (so what you - # see in W&B matches eval behavior). + # When True, the periodic wandb training-example dump uses the EMA + # shadow for the optional predicted-action columns (so what you see + # in W&B matches eval behavior). use_for_wandb_examples: bool = True diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 3d7dc083a..0b0059955 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -50,7 +50,6 @@ from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.rewards import make_reward_pre_post_processors from lerobot.utils.collate import lerobot_collate_fn -from lerobot.utils.ema import ModelEMA from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.logging_utils import AverageMeter, MetricsTracker from lerobot.utils.random_utils import set_seed @@ -658,21 +657,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # ------------------------------------------------------------------ # EMA setup # ------------------------------------------------------------------ - # Maintain a shadow copy of the trainable params for late-training - # averaging (Chi et al. 2023 Diffusion Policy §V.D; standard for - # diffusion/flow-matching policies). Lives only on the main process — - # accelerator broadcasts updates to other ranks naturally via the - # live model. Off by default; opt in with ``--ema.enable=true``. - ema: ModelEMA | None = None + # Shadow copy of the trainable params for late-training averaging + # (Chi et al. 2023 Diffusion Policy §V.D; openpi JAX trainer ships + # this with decay=0.999 for pi05_libero; openpi PyTorch port and + # LeRobot main both skip it). Off by default; opt in with + # ``--ema.enable=true``. Implemented via ema-pytorch + # (https://github.com/lucidrains/ema-pytorch) — the standard PyTorch + # EMA library, also used by lucidrains' diffusion repos. + ema = None if cfg.ema.enable and is_main_process: - ema = ModelEMA( + from ema_pytorch import EMA # noqa: PLC0415 + + ema = EMA( accelerator.unwrap_model(policy), - decay=cfg.ema.decay, - warmup_steps=cfg.ema.warmup_steps, + beta=cfg.ema.decay, + update_after_step=cfg.ema.warmup_steps, + update_every=1, # update on every ema.update() call + # Don't register the live model as an ema submodule — accelerator + # already owns its lifecycle, and double-registration would + # double-count its params in ``ema.state_dict()``. + include_online_model=False, ) + ema.to(accelerator.device) logging.info( - "EMA enabled: decay=%g, warmup_steps=%d, use_for_eval=%s, " - "use_for_wandb_examples=%s", + "EMA enabled (ema-pytorch): beta=%g, update_after_step=%d, " + "use_for_eval=%s, use_for_wandb_examples=%s", cfg.ema.decay, cfg.ema.warmup_steps, cfg.ema.use_for_eval, @@ -681,16 +690,11 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # Resume the EMA shadow if a previous run wrote one. if cfg.checkpoint_path is not None: - ema_path = ( - cfg.checkpoint_path / "training_state" / "ema_state.safetensors" - ) + ema_path = cfg.checkpoint_path / "training_state" / "ema_state.pt" if ema_path.exists(): logging.info("Resuming EMA shadow from %s", ema_path) try: - ema = ModelEMA.load_from_file( - accelerator.unwrap_model(policy), - ema_path, - ) + ema.load_state_dict(torch.load(ema_path, map_location=accelerator.device)) except Exception as exc: # noqa: BLE001 logging.warning( "Failed to load EMA shadow (%s) — restarting EMA from " @@ -753,11 +757,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # EMA update: pull one step of the live weights into the shadow. # Runs only on the main process (the shadow lives there); other # ranks rely on the live model staying in sync via accelerator. - # Returns the effective decay used (interesting during warmup). + # ``ema-pytorch`` holds an internal reference to the online model + # (set at construction), so ``ema.update()`` takes no args. if ema is not None: - ema_effective_decay = ema.update(accelerator.unwrap_model(policy)) - else: - ema_effective_decay = None + ema.update() # Note: eval and checkpoint happens *after* the `step`th training update has completed, so we # increment `step` here. @@ -800,12 +803,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if sample_weighter is not None: weighter_stats = sample_weighter.get_stats() wandb_log_dict.update({f"sample_weighting/{k}": v for k, v in weighter_stats.items()}) - # EMA observability: the effective decay differs from - # cfg.ema.decay during warmup; ``num_updates`` lets you - # confirm the EMA is actually firing. + # EMA observability: ``ema.step`` is the count of + # ``ema.update()`` calls (= optimizer steps once EMA is + # enabled); ``ema.initted`` flips to True once we've + # crossed ``update_after_step``. if ema is not None: - wandb_log_dict["ema/effective_decay"] = float(ema_effective_decay or ema.decay) - wandb_log_dict["ema/num_updates"] = int(ema.num_updates) + wandb_log_dict["ema/step"] = int(ema.step.item()) + wandb_log_dict["ema/initted"] = float(ema.initted.item()) + wandb_log_dict["ema/beta"] = float(cfg.ema.decay) wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() @@ -820,23 +825,24 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): and is_main_process ): try: - # Optionally show the EMA model's predictions in the W&B - # examples table (matches the policy you'd actually - # deploy / eval with). Live weights are restored on exit. - wandb_ema_ctx = ( - ema.apply_to(accelerator.unwrap_model(policy)) - if ema is not None and cfg.ema.use_for_wandb_examples - else nullcontext() + # Optionally use the EMA shadow model directly for the + # predicted-action columns (matches what eval / deployment + # would see). ``ema-pytorch`` exposes the shadow as a + # full ``nn.Module`` at ``ema.ema_model``, so we just + # pass that instead of swap-and-restore. + target_policy = ( + ema.ema_model + if (ema is not None and cfg.ema.use_for_wandb_examples) + else accelerator.unwrap_model(policy) + ) + wandb_logger.log_training_examples( + batch=batch, + step=step, + camera_keys=list(dataset.meta.camera_keys), + n_samples=cfg.wandb.log_examples_n, + policy=target_policy, + predict_actions=cfg.wandb.log_examples_predict_actions, ) - with wandb_ema_ctx: - wandb_logger.log_training_examples( - batch=batch, - step=step, - camera_keys=list(dataset.meta.camera_keys), - n_samples=cfg.wandb.log_examples_n, - policy=accelerator.unwrap_model(policy), - predict_actions=cfg.wandb.log_examples_predict_actions, - ) except Exception as exc: # noqa: BLE001 logging.warning("wandb log_training_examples failed: %s", exc) @@ -857,11 +863,14 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): update_last_checkpoint(checkpoint_dir) # Save the EMA shadow alongside the training state so a # resumed run picks up exactly where the live EMA left off. + # ``ema-pytorch.state_dict()`` returns the full shadow + # nn.Module's state dict + step/initted buffers; saved as + # .pt (the rest of training_state mixes formats already). if ema is not None: try: - ema.save( - checkpoint_dir / "training_state" / "ema_state.safetensors" - ) + ema_path = checkpoint_dir / "training_state" / "ema_state.pt" + ema_path.parent.mkdir(parents=True, exist_ok=True) + torch.save(ema.state_dict(), ema_path) except Exception as exc: # noqa: BLE001 logging.warning("Failed to save EMA shadow: %s", exc) if wandb_logger: @@ -873,19 +882,20 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): if is_main_process: step_id = get_step_identifier(step, cfg.steps) logging.info(f"Eval policy at step {step}") - # Optionally swap in EMA weights for eval — standard - # practice for diffusion-style policies (~1–3% lift on - # closed-loop success). Live weights are restored byte- - # for-byte on context exit. - eval_ema_ctx = ( - ema.apply_to(accelerator.unwrap_model(policy)) - if ema is not None and cfg.ema.use_for_eval - else nullcontext() + # Use the EMA shadow model for eval when enabled — + # standard practice for diffusion-style policies (~1–3% + # lift on closed-loop success). ``ema.ema_model`` is a + # full nn.Module clone, so we just pass it through; no + # swap/restore on the live policy needed. + eval_target_policy = ( + ema.ema_model + if (ema is not None and cfg.ema.use_for_eval) + else accelerator.unwrap_model(policy) ) - with eval_ema_ctx, torch.no_grad(), accelerator.autocast(): + with torch.no_grad(), accelerator.autocast(): eval_info = eval_policy_all( envs=eval_env, # dict[suite][task_id] -> vec_env - policy=accelerator.unwrap_model(policy), + policy=eval_target_policy, env_preprocessor=env_preprocessor, env_postprocessor=env_postprocessor, preprocessor=preprocessor, diff --git a/src/lerobot/utils/ema.py b/src/lerobot/utils/ema.py deleted file mode 100644 index c0ece76d0..000000000 --- a/src/lerobot/utils/ema.py +++ /dev/null @@ -1,250 +0,0 @@ -# Copyright 2026 The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Exponential Moving Average of model parameters for training stability. - -Maintains a shadow copy of every trainable parameter, updated after each -optimizer step:: - - θ_ema ← β · θ_ema + (1 - β) · θ_live - -At eval / inference / final checkpoint, use ``θ_ema`` instead of -``θ_live``. For diffusion / flow-matching policies, averaging late- -training oscillations yields a smoother model that generalises -substantially better at inference — see Chi et al. 2023 (Diffusion -Policy §V.D, β=0.75), Ho et al. 2020 (DDPM appendix). For VLAs with a -flow-matching action expert the same logic applies: flow gradients have -high variance per sample (different noise levels in the same batch), -so EMA smooths over that variance. - -Cost: 1× model parameters in fp32 (~13 GB for pi052's 3.3B params), -plus one elementwise update per training step (~1% of step time). -""" - -from __future__ import annotations - -import json -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Iterator - -import torch -import torch.nn as nn -from safetensors.torch import load_file, save_file - - -class ModelEMA: - """Exponential moving average of trainable model parameters. - - Args: - model: The live model whose parameter shapes/names define the - shadow. Only parameters with ``requires_grad=True`` are - tracked. Buffers are intentionally NOT tracked (LayerNorm - running stats, RoPE caches, etc.) — they are updated in - ``train()`` mode regardless of which weights we apply. - decay: Target EMA decay (``β`` in - ``θ_ema ← β·θ_ema + (1-β)·θ_live``). Typical values: - - * ``0.999`` — averages roughly the last 1000 steps. Standard - for diffusion-style policies and the default here. - * ``0.75`` — very fast EMA, used by Diffusion Policy (Chi - et al. 2023). Useful when training is short or noisy. - * ``0.9999`` — very slow EMA, used in image classification - for very long runs. - warmup_steps: If > 0, ramp the effective decay from a low value - up to ``decay`` over the first ``warmup_steps`` updates as - ``min(decay, (1 + n) / (10 + n))``. Lets the EMA track - rapid early-training changes before settling on the target. - device: Where to keep the shadow parameters. ``None`` keeps each - shadow on its parameter's device (good for FSDP / multi-GPU). - Pass an explicit device to relocate everything (e.g. ``"cpu"`` - to free GPU memory at the cost of slower updates). - """ - - def __init__( - self, - model: nn.Module, - *, - decay: float = 0.999, - warmup_steps: int = 0, - device: torch.device | str | None = None, - ) -> None: - if not 0.0 < decay < 1.0: - raise ValueError(f"decay must be in (0, 1), got {decay}") - self.decay = float(decay) - self.warmup_steps = int(warmup_steps) - self.num_updates = 0 - self.device = torch.device(device) if device is not None else None - - # fp32 shadow — small EMA updates lose precision in bf16. - self.shadow: dict[str, torch.Tensor] = {} - for name, p in model.named_parameters(): - if not p.requires_grad: - continue - shadow_p = p.detach().clone().float() - if self.device is not None: - shadow_p = shadow_p.to(self.device) - self.shadow[name] = shadow_p - - # ------------------------------------------------------------------ - # Core update - # ------------------------------------------------------------------ - - def _effective_decay(self) -> float: - if self.warmup_steps <= 0 or self.num_updates >= self.warmup_steps: - return self.decay - # Standard EMA warmup (timm / diffusers convention): grows - # 0.09, 0.16, 0.23, ... and saturates at ``decay``. - return min(self.decay, (1.0 + self.num_updates) / (10.0 + self.num_updates)) - - @torch.no_grad() - def update(self, model: nn.Module) -> float: - """Pull one update from the live model into the shadow. - - Returns the effective decay used this step (useful to log during - warmup, when the value differs from ``self.decay``). - """ - self.num_updates += 1 - beta = self._effective_decay() - one_minus_beta = 1.0 - beta - for name, p in model.named_parameters(): - if not p.requires_grad: - continue - shadow = self.shadow.get(name) - if shadow is None: - # New parameter appeared mid-training — seed it. - shadow = p.detach().clone().float() - if self.device is not None: - shadow = shadow.to(self.device) - self.shadow[name] = shadow - continue - # In-place fused: shadow ← β · shadow + (1 - β) · p - shadow.mul_(beta).add_( - p.detach().to(shadow.device, dtype=torch.float32), - alpha=one_minus_beta, - ) - return beta - - # ------------------------------------------------------------------ - # Applying the EMA to the live model - # ------------------------------------------------------------------ - - @torch.no_grad() - def copy_to(self, model: nn.Module) -> None: - """Overwrite the live model's parameters with the EMA shadow. - - In-place and **irreversible** — the previous live weights are - lost. Use this only at the very end of training when you want - the EMA to *be* the final saved policy. For temporary swaps - (e.g. during eval), use :meth:`apply_to`. - """ - for name, p in model.named_parameters(): - shadow = self.shadow.get(name) - if shadow is not None: - p.data.copy_(shadow.to(p.device, dtype=p.dtype)) - - @contextmanager - def apply_to(self, model: nn.Module) -> Iterator[None]: - """Temporarily swap the live model's weights with the EMA copy. - - On exit, the original live weights are restored byte-for-byte - (we keep a backup clone of every tracked parameter inside the - context). Use this around eval / sample-logging without - disturbing the live training state:: - - with ema.apply_to(policy): - eval_metrics = evaluate(policy) - # policy is back to its pre-eval state here. - """ - backup: dict[str, torch.Tensor] = {} - for name, p in model.named_parameters(): - if name in self.shadow: - backup[name] = p.detach().clone() - p.data.copy_(self.shadow[name].to(p.device, dtype=p.dtype)) - try: - yield - finally: - for name, p in model.named_parameters(): - if name in backup: - p.data.copy_(backup[name].to(p.device, dtype=p.dtype)) - - # ------------------------------------------------------------------ - # Checkpointing - # ------------------------------------------------------------------ - - def state_dict(self) -> dict[str, Any]: - return { - "shadow": {k: v.detach().cpu() for k, v in self.shadow.items()}, - "num_updates": self.num_updates, - "decay": self.decay, - "warmup_steps": self.warmup_steps, - } - - def load_state_dict(self, state: dict[str, Any]) -> None: - self.decay = float(state["decay"]) - self.warmup_steps = int(state["warmup_steps"]) - self.num_updates = int(state["num_updates"]) - new_shadow: dict[str, torch.Tensor] = {} - for k, v in state["shadow"].items(): - t = v.detach() - if self.device is not None: - t = t.to(self.device) - new_shadow[k] = t.float() - self.shadow = new_shadow - - def save(self, path: Path | str) -> None: - """Save the shadow as safetensors + a tiny JSON sidecar with metadata. - - Sidecar lives at ``.json`` and stores ``num_updates``, - ``decay``, ``warmup_steps`` — enough to resume exact EMA state. - """ - path = Path(path) - path.parent.mkdir(parents=True, exist_ok=True) - save_file( - {k: v.detach().cpu().contiguous() for k, v in self.shadow.items()}, - str(path), - ) - meta = { - "num_updates": self.num_updates, - "decay": self.decay, - "warmup_steps": self.warmup_steps, - } - path.with_suffix(path.suffix + ".json").write_text(json.dumps(meta, indent=2)) - - @classmethod - def load_from_file( - cls, - model: nn.Module, - path: Path | str, - *, - device: torch.device | str | None = None, - ) -> "ModelEMA": - """Reconstruct a ``ModelEMA`` from a previously-saved safetensors + sidecar pair.""" - path = Path(path) - meta_path = path.with_suffix(path.suffix + ".json") - meta = json.loads(meta_path.read_text()) if meta_path.exists() else {} - ema = cls( - model, - decay=float(meta.get("decay", 0.999)), - warmup_steps=int(meta.get("warmup_steps", 0)), - device=device, - ) - shadow = load_file(str(path)) - target_device = ema.device - ema.shadow = { - k: (v.to(target_device) if target_device is not None else v).float() - for k, v in shadow.items() - } - ema.num_updates = int(meta.get("num_updates", 0)) - return ema