From efa05f0adaa4e39bacdc547b5fa8ad96a91a698d Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 21 May 2026 13:41:20 +0200 Subject: [PATCH] fix(train): unwrap DDP policy in debug_text_predictions hook At training time the policy is wrapped by Accelerator/DDP into a .module attribute and custom methods are NOT proxied through the wrapper, so ``hasattr(policy, "debug_text_predictions")`` was False and the periodic dump was silently no-op'ing. Walk through .module indirection to reach the raw PI052Policy that defines the method. Also surface why the dump didn't fire (no method / empty supervised positions / generation error) so users can see what's blocking it instead of staring at silence. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/scripts/lerobot_train.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 3a64912ed..9b9cf659e 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -170,15 +170,34 @@ def _print_debug_text_predictions( per-sample token-accuracy summary — the cheapest "is text training actually learning anything" signal. """ - if not hasattr(policy, "debug_text_predictions"): + # Accelerator/DDP wraps the policy in a ``module`` attribute and + # doesn't proxy custom methods through, so a naive + # ``hasattr(policy, "debug_text_predictions")`` returns False on the + # wrapper — and the helper would silently no-op. Walk through any + # ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers) + # to reach the raw policy that actually defines the method. + inner = policy + while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"): + inner = inner.module + if not hasattr(inner, "debug_text_predictions"): + logging.warning( + "LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no " + "debug_text_predictions method — skipping dump.", + type(inner).__name__, + ) return try: - debug = policy.debug_text_predictions(batch, max_samples=n_samples) + debug = inner.debug_text_predictions(batch, max_samples=n_samples) except Exception as exc: # noqa: BLE001 - logging.warning("debug_text_predictions failed: %s", exc) + logging.warning("debug_text_predictions failed: %s", exc, exc_info=True) return if not debug: + logging.warning( + "debug_text_predictions returned no supervised samples — " + "current batch has no text labels." + ) return + policy = inner # used below for select_message-style decoding parity # Build a tokenizer for decoding — match training side exactly. try: