feat(train): periodic LM-head prediction dump for live debugging

Adds an opt-in diagnostic that, every N training steps, dumps 5 batch
samples plus the LM head's argmax prediction at every supervised
position alongside the label and a ✓/✗ marker — the cheapest signal
for "is text training actually learning what we expect, or collapsing
to a fixed token". Refills the recipe-sample dump budget on the same
cadence so the raw input shapes are also re-dumped.

Opt in via env var:
  LEROBOT_DEBUG_PREDS_EVERY=1000 lerobot-train ...

PI052 implements ``debug_text_predictions`` (mirrors the text-loss
forward but returns argmax instead of CE); other policies are silently
skipped. The dump runs in eval() mode under no_grad, slicing the
current batch to N samples — no extra data fetch, no train-state
mutation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-21 12:23:05 +02:00
parent 86ecd4bc2e
commit f7747d02a9
2 changed files with 190 additions and 0 deletions
@@ -782,6 +782,90 @@ class PI052Policy(PI05Policy):
return text_loss, fast_loss
# ------------------------------------------------------------------
# Diagnostic: forward + argmax for supervised text positions
# ------------------------------------------------------------------
@torch.no_grad()
def debug_text_predictions(
    self, batch: dict[str, Tensor], max_samples: int = 5
) -> dict[str, Tensor]:
    """Run the text-loss forward but return argmax predictions instead of CE.

    Lets a periodic training-loop hook compare what the LM head emits
    right now against what it *should* emit at every supervised
    position — the cheapest "is text training actually working"
    diagnostic.

    Args:
        batch: Training batch. Reads ``text_labels``, the language token /
            attention-mask entries and every tensor whose key starts with
            ``observation.images.``.
        max_samples: Upper bound on how many leading batch rows are
            forwarded. Non-positive values yield an empty result.

    Returns:
        CPU tensors keyed by ``input_ids``, ``attention_mask``, ``labels``
        and ``predictions``. Predictions are aligned with input positions
        (``predictions[t]`` is the head's argmax after seeing
        ``input_ids[:t+1]``, so it should match ``input_ids[t+1]`` for
        next-token prediction). Returns ``{}`` when the batch has no
        supervised text positions or ``max_samples`` is non-positive.
    """
    text_labels = batch.get("text_labels")
    if text_labels is None or not bool((text_labels != -100).any().item()):
        return {}
    n = min(max_samples, int(text_labels.shape[0]))
    if n <= 0:
        # Nothing to forward; a zero/negative slice would otherwise feed an
        # empty (or truncated-from-the-end) batch to the backbone.
        return {}

    # Imported lazily, and only once we know there is work to do.
    from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415

    # Snapshot the mode of every submodule. Restoring with a blanket
    # ``self.train()`` would force *all* submodules into train mode, which
    # clobbers any submodule that was deliberately left in eval mode.
    saved_modes = [(module, module.training) for module in self.modules()]
    self.eval()
    try:
        sub: dict[str, Any] = {
            OBS_LANGUAGE_TOKENS: batch[OBS_LANGUAGE_TOKENS][:n],
            OBS_LANGUAGE_ATTENTION_MASK: batch[OBS_LANGUAGE_ATTENTION_MASK][:n],
        }
        for k, v in batch.items():
            if isinstance(k, str) and k.startswith("observation.images.") and torch.is_tensor(v):
                sub[k] = v[:n]
        sub_labels = text_labels[:n]

        images, img_masks = self._preprocess_images(sub)
        lang_tokens = sub[OBS_LANGUAGE_TOKENS]
        lang_masks = sub[OBS_LANGUAGE_ATTENTION_MASK]
        prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )

        # The label span is taken to be the trailing ``sub_labels.shape[1]``
        # positions of the prefix; only mark it when the prefix is at least
        # that long.
        lang_start = prefix_embs.shape[1] - sub_labels.shape[1]
        if lang_start >= 0:
            prefix_att = _mark_target_span_causal(
                prefix_att, sub_labels, lang_start, prefix_embs.shape[1]
            )

        att_2d = make_att_2d_masks(prefix_pad, prefix_att)
        position_ids = torch.cumsum(prefix_pad, dim=1) - 1
        att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)

        backbone = self.model.paligemma_with_expert
        # Match the attention mask dtype to the backbone's weights.
        backbone_dtype = (
            backbone.paligemma.model.language_model.layers[0]
            .self_attn.q_proj.weight.dtype
        )
        if att_2d_4d.dtype != backbone_dtype:
            att_2d_4d = att_2d_4d.to(dtype=backbone_dtype)

        (vlm_out, _), _ = backbone.forward(
            attention_mask=att_2d_4d,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=False,
        )

        # Keep only the trailing text positions and project them to logits.
        text_hidden = vlm_out[:, -sub_labels.shape[1]:, :]
        lm_head = backbone.paligemma.lm_head
        text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
        preds = text_logits.argmax(dim=-1)

        return {
            "input_ids": lang_tokens.detach().cpu(),
            "attention_mask": lang_masks.detach().cpu(),
            "labels": sub_labels.detach().cpu(),
            "predictions": preds.detach().cpu(),
        }
    finally:
        # Put every submodule back exactly as it was found.
        for module, was_training in saved_modes:
            module.training = was_training
# ------------------------------------------------------------------
# select_message — AR text generation at inference
# ------------------------------------------------------------------
+106
View File
@@ -20,6 +20,7 @@ Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wand
import dataclasses
import logging
import os
import time
from contextlib import nullcontext
from pprint import pformat
@@ -156,6 +157,90 @@ def update_policy(
return train_metrics, output_dict
def _format_debug_text_predictions(debug: dict[str, Any], tokenizer: Any, step: int) -> str:
    """Render a ``debug_text_predictions`` dump as a single printable report."""
    ids = debug["input_ids"]
    labels = debug["labels"]
    preds = debug["predictions"]
    attn = debug["attention_mask"]
    n = ids.shape[0]

    # Ids at or above this bound are shown as "<oob>" instead of decoded.
    # NOTE(review): the 2048 headroom presumably covers tokens added by
    # ``register_paligemma_loc_tokens`` — confirm against the tokenizer.
    id_limit = tokenizer.vocab_size + 2048

    def _decode_one(token_id: int) -> str:
        return tokenizer.decode([token_id]) if 0 <= token_id < id_limit else "<oob>"

    lines = [f"\n========== STEP {step} DEBUG PREDICTIONS ({n} samples) =========="]
    for s in range(n):
        # Number of attended tokens; the slice below assumes they are the
        # leading ones (right padding) — TODO confirm.
        real = sum(attn[s].tolist())
        sid = ids[s].tolist()
        sl = labels[s].tolist()
        sp = preds[s].tolist()
        prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
        lines.append(f"\n --- sample {s + 1}/{n} ---")
        lines.append(f" prompt: {prompt!r}")

        n_sup = n_ok = 0
        # CE shift: pred[t] predicts label[t+1]. Iterate supervised label
        # positions (i = t+1) and align with prediction at t = i-1.
        for i in range(1, real):
            label = sl[i]
            if label == -100:
                continue
            n_sup += 1
            pred = sp[i - 1]
            ok = label == pred
            n_ok += int(ok)
            mark = "✓" if ok else "✗"
            lines.append(
                f" pos {i - 1:3d}→{i:3d}: label {label:6d} {_decode_one(label)!r:20s} | "
                f"pred {pred:6d} {_decode_one(pred)!r:20s} {mark}"
            )

        acc = n_ok / max(n_sup, 1)
        lines.append(f" token accuracy: {n_ok}/{n_sup} = {acc:.1%}")
    lines.append("=" * 60 + "\n")
    return "\n".join(lines)


def _print_debug_text_predictions(
    policy: Any, batch: dict[str, Any], step: int, n_samples: int = 5
) -> None:
    """Forward the current batch and print head-argmax vs label per supervised position.

    Opt-in via ``LEROBOT_DEBUG_PREDS_EVERY=<step_interval>``. Only the
    policy types that expose ``debug_text_predictions`` participate
    (currently PI052); others are silently skipped. Pretty-prints up to
    ``n_samples`` samples from the current batch, showing the prompt,
    every supervised position's (label, prediction, ✓/✗), and a
    per-sample token-accuracy summary — the cheapest "is text training
    actually learning anything" signal.

    This is a best-effort diagnostic: a failure in the forward pass, the
    tokenizer load or the rendering is logged as a warning and the
    function returns without raising.

    Args:
        policy: Policy object; must expose ``debug_text_predictions`` and
            ``config`` to produce any output.
        batch: The current training batch, forwarded as-is to the policy.
        step: Training step number, used only in the report header.
        n_samples: Maximum number of batch samples to dump.
    """
    if not hasattr(policy, "debug_text_predictions"):
        return
    try:
        debug = policy.debug_text_predictions(batch, max_samples=n_samples)
    except Exception as exc:  # noqa: BLE001
        logging.warning("debug_text_predictions failed: %s", exc)
        return
    if not debug:
        return

    # Build a tokenizer for decoding — match training side exactly.
    try:
        from transformers import AutoTokenizer  # noqa: PLC0415

        from lerobot.policies.pi052.text_processor_pi052 import (  # noqa: PLC0415
            register_paligemma_loc_tokens,
        )

        tok_name = (
            getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
        )
        tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
    except Exception as exc:  # noqa: BLE001
        logging.warning("debug preds: tokenizer load failed: %s", exc)
        return

    # Rendering is guarded like the steps above so that a shape mismatch in
    # a diagnostic can never take the training loop down.
    try:
        report = _format_debug_text_predictions(debug, tokenizer, step)
    except Exception as exc:  # noqa: BLE001
        logging.warning("debug preds: rendering failed: %s", exc)
        return
    # One flushed write instead of a flush per line.
    print(report, flush=True)
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
"""Build per-frame sampling weights that oversample VQA-annotated frames.
@@ -542,6 +627,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
# Optional periodic head-prediction dump for the LM head:
# ``LEROBOT_DEBUG_PREDS_EVERY=1000`` prints 5 samples + per-token
# (label, argmax, ✓/✗) every 1000 steps. Cheap diagnostic to see
# whether the text head is actually learning what we expect, vs
# collapsing to a fixed token. Refilling the recipe-sample dump
# budget at the same cadence also redumps the raw input shapes.
_debug_preds_every = int(os.environ.get("LEROBOT_DEBUG_PREDS_EVERY", "0"))
if (
_debug_preds_every > 0
and step % _debug_preds_every == 0
and is_main_process
):
try:
from lerobot.policies.pi052 import text_processor_pi052 as _tp # noqa: PLC0415
_tp._DUMPED_SO_FAR = 0
_tp._DUMP_BUDGET = max(_tp._DUMP_BUDGET, 5)
except Exception: # noqa: BLE001
pass
_print_debug_text_predictions(policy, batch, step, n_samples=5)
if is_log_step:
logging.info(train_tracker)
if wandb_logger: