fix(train): unwrap DDP policy in debug_text_predictions hook

At training time the policy is wrapped by Accelerator/DDP into a
.module attribute and custom methods are NOT proxied through the
wrapper, so ``hasattr(policy, "debug_text_predictions")`` was False
and the periodic dump was silently no-op'ing. Walk through .module
indirection to reach the raw PI052Policy that defines the method.

Also surface why the dump didn't fire (no method / empty supervised
positions / generation error) so users can see what's blocking it
instead of staring at silence.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-21 13:41:20 +02:00
parent e98b6f726b
commit efa05f0ada
+22 -3
View File
@@ -170,15 +170,34 @@ def _print_debug_text_predictions(
per-sample token-accuracy summary — the cheapest "is text training
actually learning anything" signal.
"""
if not hasattr(policy, "debug_text_predictions"):
# Accelerator/DDP wraps the policy in a ``module`` attribute and
# doesn't proxy custom methods through, so a naive
# ``hasattr(policy, "debug_text_predictions")`` returns False on the
# wrapper — and the helper would silently no-op. Walk through any
# ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers)
# to reach the raw policy that actually defines the method.
inner = policy
while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"):
inner = inner.module
if not hasattr(inner, "debug_text_predictions"):
logging.warning(
"LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no "
"debug_text_predictions method — skipping dump.",
type(inner).__name__,
)
return
try:
debug = policy.debug_text_predictions(batch, max_samples=n_samples)
debug = inner.debug_text_predictions(batch, max_samples=n_samples)
except Exception as exc: # noqa: BLE001
logging.warning("debug_text_predictions failed: %s", exc)
logging.warning("debug_text_predictions failed: %s", exc, exc_info=True)
return
if not debug:
logging.warning(
"debug_text_predictions returned no supervised samples — "
"current batch has no text labels."
)
return
policy = inner # used below for select_message-style decoding parity
# Build a tokenizer for decoding — match training side exactly.
try: