mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
fix(train): unwrap DDP policy in debug_text_predictions hook
At training time the policy is wrapped by Accelerator/DDP into a .module attribute and custom methods are NOT proxied through the wrapper, so ``hasattr(policy, "debug_text_predictions")`` was False and the periodic dump was silently no-op'ing. Walk through .module indirection to reach the raw PI052Policy that defines the method. Also surface why the dump didn't fire (no method / empty supervised positions / generation error) so users can see what's blocking it instead of staring at silence. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -170,15 +170,34 @@ def _print_debug_text_predictions(
|
||||
per-sample token-accuracy summary — the cheapest "is text training
|
||||
actually learning anything" signal.
|
||||
"""
|
||||
if not hasattr(policy, "debug_text_predictions"):
|
||||
# Accelerator/DDP wraps the policy in a ``module`` attribute and
|
||||
# doesn't proxy custom methods through, so a naive
|
||||
# ``hasattr(policy, "debug_text_predictions")`` returns False on the
|
||||
# wrapper — and the helper would silently no-op. Walk through any
|
||||
# ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers)
|
||||
# to reach the raw policy that actually defines the method.
|
||||
inner = policy
|
||||
while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"):
|
||||
inner = inner.module
|
||||
if not hasattr(inner, "debug_text_predictions"):
|
||||
logging.warning(
|
||||
"LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no "
|
||||
"debug_text_predictions method — skipping dump.",
|
||||
type(inner).__name__,
|
||||
)
|
||||
return
|
||||
try:
|
||||
debug = policy.debug_text_predictions(batch, max_samples=n_samples)
|
||||
debug = inner.debug_text_predictions(batch, max_samples=n_samples)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logging.warning("debug_text_predictions failed: %s", exc)
|
||||
logging.warning("debug_text_predictions failed: %s", exc, exc_info=True)
|
||||
return
|
||||
if not debug:
|
||||
logging.warning(
|
||||
"debug_text_predictions returned no supervised samples — "
|
||||
"current batch has no text labels."
|
||||
)
|
||||
return
|
||||
policy = inner # used below for select_message-style decoding parity
|
||||
|
||||
# Build a tokenizer for decoding — match training side exactly.
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user