mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
fix(train): unwrap DDP policy in debug_text_predictions hook
At training time the policy is wrapped by Accelerator/DDP, which exposes the raw policy through a ``.module`` attribute and does NOT proxy custom methods through the wrapper, so ``hasattr(policy, "debug_text_predictions")`` was False and the periodic dump was silently a no-op. Walk through any ``.module`` indirection to reach the raw PI052Policy that defines the method. Also surface why the dump didn't fire (no method / empty supervised positions / generation error) so users can see what's blocking it instead of staring at silence.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -170,15 +170,34 @@ def _print_debug_text_predictions(
|
|||||||
per-sample token-accuracy summary — the cheapest "is text training
|
per-sample token-accuracy summary — the cheapest "is text training
|
||||||
actually learning anything" signal.
|
actually learning anything" signal.
|
||||||
"""
|
"""
|
||||||
if not hasattr(policy, "debug_text_predictions"):
|
# Accelerator/DDP wraps the policy in a ``module`` attribute and
|
||||||
|
# doesn't proxy custom methods through, so a naive
|
||||||
|
# ``hasattr(policy, "debug_text_predictions")`` returns False on the
|
||||||
|
# wrapper — and the helper would silently no-op. Walk through any
|
||||||
|
# ``.module`` indirection (DDP, FSDP, ``accelerator.prepare`` wrappers)
|
||||||
|
# to reach the raw policy that actually defines the method.
|
||||||
|
inner = policy
|
||||||
|
while hasattr(inner, "module") and not hasattr(inner, "debug_text_predictions"):
|
||||||
|
inner = inner.module
|
||||||
|
if not hasattr(inner, "debug_text_predictions"):
|
||||||
|
logging.warning(
|
||||||
|
"LEROBOT_DEBUG_PREDS_EVERY set but policy %s has no "
|
||||||
|
"debug_text_predictions method — skipping dump.",
|
||||||
|
type(inner).__name__,
|
||||||
|
)
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
debug = policy.debug_text_predictions(batch, max_samples=n_samples)
|
debug = inner.debug_text_predictions(batch, max_samples=n_samples)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logging.warning("debug_text_predictions failed: %s", exc)
|
logging.warning("debug_text_predictions failed: %s", exc, exc_info=True)
|
||||||
return
|
return
|
||||||
if not debug:
|
if not debug:
|
||||||
|
logging.warning(
|
||||||
|
"debug_text_predictions returned no supervised samples — "
|
||||||
|
"current batch has no text labels."
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
policy = inner # used below for select_message-style decoding parity
|
||||||
|
|
||||||
# Build a tokenizer for decoding — match training side exactly.
|
# Build a tokenizer for decoding — match training side exactly.
|
||||||
try:
|
try:
|
||||||
|
|||||||
Reference in New Issue
Block a user