feat(train): periodic LM-head prediction dump for live debugging

Adds an opt-in diagnostic that, every N training steps, dumps 5 batch
samples plus the LM head's argmax prediction at every supervised
position alongside the label and a ✓/✗ marker — the cheapest signal
for "is text training actually learning what we expect, or collapsing
to a fixed token". Refills the recipe-sample dump budget on the same
cadence so the raw input shapes are also re-dumped.

Opt in via env var:
  LEROBOT_DEBUG_PREDS_EVERY=1000 lerobot-train ...

PI052 implements ``debug_text_predictions`` (mirrors the text-loss
forward but returns argmax instead of CE); other policies are silently
skipped. The dump runs in eval() mode under no_grad, slicing the
current batch to N samples — no extra data fetch, no train-state
mutation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-21 12:23:05 +02:00
parent 86ecd4bc2e
commit f7747d02a9
2 changed files with 190 additions and 0 deletions
@@ -782,6 +782,90 @@ class PI052Policy(PI05Policy):
return text_loss, fast_loss return text_loss, fast_loss
# ------------------------------------------------------------------
# Diagnostic: forward + argmax for supervised text positions
# ------------------------------------------------------------------
@torch.no_grad()
def debug_text_predictions(
    self, batch: dict[str, Tensor], max_samples: int = 5
) -> dict[str, Tensor]:
    """Run the text-loss forward but return argmax predictions instead of CE.

    Lets a periodic training-loop hook compare what the LM head emits
    right now against what it *should* emit at every supervised
    position — the cheapest "is text training actually working"
    diagnostic.

    Up to ``max_samples`` rows are taken from ``batch``, choosing the
    first rows whose ``text_labels`` contain at least one supervised
    (``!= -100``) position, so a dump is never spent on rows that have
    nothing to compare.

    Args:
        batch: Training batch. Reads ``"text_labels"``, the language token
            and attention-mask entries, and every tensor whose key starts
            with ``"observation.images."``.
        max_samples: Upper bound on the number of rows forwarded.

    Returns:
        CPU tensors keyed by ``input_ids``, ``attention_mask``, ``labels``
        and ``predictions``. Predictions are aligned with input positions
        (``predictions[t]`` is the head's argmax after seeing
        ``input_ids[:t+1]``, so it should match ``input_ids[t+1]`` for
        next-token prediction). Returns ``{}`` when ``text_labels`` is
        missing, has fewer than two dimensions, contains no supervised
        position, or when ``max_samples`` is not positive.
    """
    text_labels = batch.get("text_labels")
    if text_labels is None or max_samples <= 0 or text_labels.dim() < 2:
        return {}

    # Pick the rows that actually carry supervision. A plain Python list is
    # used as the index so it works whatever device each batch tensor is on.
    has_target = (text_labels != -100).flatten(1).any(dim=1)
    rows = has_target.nonzero().flatten()[:max_samples].tolist()
    if not rows:
        return {}

    # Deferred until a forward pass is really needed.
    from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415

    was_training = self.training
    self.eval()
    try:
        sub: dict[str, Any] = {
            OBS_LANGUAGE_TOKENS: batch[OBS_LANGUAGE_TOKENS][rows],
            OBS_LANGUAGE_ATTENTION_MASK: batch[OBS_LANGUAGE_ATTENTION_MASK][rows],
        }
        for k, v in batch.items():
            if isinstance(k, str) and k.startswith("observation.images.") and torch.is_tensor(v):
                sub[k] = v[rows]
        sub_labels = text_labels[rows]

        images, img_masks = self._preprocess_images(sub)
        lang_tokens = sub[OBS_LANGUAGE_TOKENS]
        lang_masks = sub[OBS_LANGUAGE_ATTENTION_MASK]
        prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
            images, img_masks, lang_tokens, lang_masks
        )

        # The label span is taken to be the trailing ``sub_labels.shape[1]``
        # positions of the prefix.
        # NOTE(review): when the prefix is shorter than the labels
        # (lang_start < 0) the causal marking is skipped and the tail slice
        # below yields fewer positions than labels — confirm this cannot
        # happen with the real processor.
        lang_start = prefix_embs.shape[1] - sub_labels.shape[1]
        if lang_start >= 0:
            prefix_att = _mark_target_span_causal(
                prefix_att, sub_labels, lang_start, prefix_embs.shape[1]
            )

        att_2d = make_att_2d_masks(prefix_pad, prefix_att)
        position_ids = torch.cumsum(prefix_pad, dim=1) - 1
        att_2d_4d = self.model._prepare_attention_masks_4d(att_2d)

        backbone = self.model.paligemma_with_expert
        # Cast the mask to the dtype of the backbone's attention weights.
        backbone_dtype = (
            backbone.paligemma.model.language_model.layers[0]
            .self_attn.q_proj.weight.dtype
        )
        if att_2d_4d.dtype != backbone_dtype:
            att_2d_4d = att_2d_4d.to(dtype=backbone_dtype)

        (vlm_out, _), _ = backbone.forward(
            attention_mask=att_2d_4d,
            position_ids=position_ids,
            past_key_values=None,
            inputs_embeds=[prefix_embs, None],
            use_cache=False,
        )

        # Keep only the trailing text positions and project them to logits.
        text_hidden = vlm_out[:, -sub_labels.shape[1]:, :]
        lm_head = backbone.paligemma.lm_head
        text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
        preds = text_logits.argmax(dim=-1)

        return {
            "input_ids": lang_tokens.detach().cpu(),
            "attention_mask": lang_masks.detach().cpu(),
            "labels": sub_labels.detach().cpu(),
            "predictions": preds.detach().cpu(),
        }
    finally:
        # Restore training mode only if it was active on entry.
        # NOTE(review): ``self.train()`` is called with no arguments — confirm
        # no submodule is meant to stay in eval mode during training.
        if was_training:
            self.train()
# ------------------------------------------------------------------
# select_message — AR text generation at inference
# ------------------------------------------------------------------
+106
View File
@@ -20,6 +20,7 @@ Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wand
import dataclasses import dataclasses
import logging import logging
import os
import time import time
from contextlib import nullcontext from contextlib import nullcontext
from pprint import pformat from pprint import pformat
@@ -156,6 +157,90 @@ def update_policy(
return train_metrics, output_dict return train_metrics, output_dict
# Tokenizers already built for the debug dump, keyed by tokenizer name, so the
# periodic hook does not call ``AutoTokenizer.from_pretrained`` on every dump.
_DEBUG_TOKENIZER_CACHE: dict[str, Any] = {}


def _print_debug_text_predictions(
    policy: Any, batch: dict[str, Any], step: int, n_samples: int = 5
) -> None:
    """Forward the current batch and print head-argmax vs label per supervised position.

    Opt-in via ``LEROBOT_DEBUG_PREDS_EVERY=<step_interval>``. Only the
    policy types that expose ``debug_text_predictions`` participate
    (currently PI052); others are silently skipped. Pretty-prints up to
    ``n_samples`` samples from the current batch, showing the prompt,
    every supervised position's (label, prediction, ✓/✗), and a
    per-sample token-accuracy summary — the cheapest "is text training
    actually learning anything" signal.

    Args:
        policy: Policy object; must provide ``debug_text_predictions`` and a
            ``config`` attribute to take part.
        batch: Current training batch, passed through to the policy.
        step: Training step, used only in the printed header.
        n_samples: Maximum number of samples requested from the policy.

    Failures in the policy forward or in tokenizer loading are logged as
    warnings and end the dump; they are never raised to the caller.
    """
    if not hasattr(policy, "debug_text_predictions"):
        return
    try:
        debug = policy.debug_text_predictions(batch, max_samples=n_samples)
    except Exception as exc:  # noqa: BLE001
        logging.warning("debug_text_predictions failed: %s", exc)
        return
    if not debug:
        return

    # Build (or reuse) a tokenizer for decoding — match training side exactly.
    try:
        tok_name = (
            getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224"
        )
        tokenizer = _DEBUG_TOKENIZER_CACHE.get(tok_name)
        if tokenizer is None:
            from transformers import AutoTokenizer  # noqa: PLC0415

            from lerobot.policies.pi052.text_processor_pi052 import (  # noqa: PLC0415
                register_paligemma_loc_tokens,
            )

            tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name))
            _DEBUG_TOKENIZER_CACHE[tok_name] = tokenizer
    except Exception as exc:  # noqa: BLE001
        logging.warning("debug preds: tokenizer load failed: %s", exc)
        return

    ids = debug["input_ids"]
    labels = debug["labels"]
    preds = debug["predictions"]
    attn = debug["attention_mask"]
    n = ids.shape[0]

    # Ids outside this bound are shown as "<oob>" instead of being decoded.
    # NOTE(review): the 2048 slack presumably covers tokens registered on top
    # of the base vocabulary — confirm against the tokenizer setup.
    max_id = tokenizer.vocab_size + 2048

    def _decode_one(token_id: int) -> str:
        return tokenizer.decode([token_id]) if 0 <= token_id < max_id else "<oob>"

    # Collect everything and emit it with one print so the dump reaches the
    # stream as a single write instead of one flush per row.
    lines: list[str] = [
        f"\n========== STEP {step} DEBUG PREDICTIONS ({n} samples) =========="
    ]
    samples = zip(ids.tolist(), labels.tolist(), preds.tolist(), attn.tolist())
    for number, (sid, sl, sp, mask) in enumerate(samples, start=1):
        # Number of attended tokens; the first ``real`` tokens are decoded as
        # the prompt.
        real = sum(mask)
        prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
        lines.append(f"\n --- sample {number}/{n} ---")
        lines.append(f" prompt: {prompt!r}")

        n_sup = n_ok = 0
        # CE shift: pred[t] predicts label[t+1]. Iterate supervised label
        # positions (i = t+1) and align with prediction at t = i-1. The bound
        # also respects the label row length so a short row cannot index out
        # of range.
        for i in range(1, min(real, len(sl))):
            label = sl[i]
            if label == -100:
                continue
            n_sup += 1
            pred = sp[i - 1]
            ok = label == pred
            n_ok += int(ok)
            mark = "✓" if ok else "✗"
            lines.append(
                f" pos {i - 1:3d}→{i:3d}: label {label:6d} {_decode_one(label)!r:20s} | "
                f"pred {pred:6d} {_decode_one(pred)!r:20s} {mark}"
            )

        acc = n_ok / max(n_sup, 1)
        lines.append(f" token accuracy: {n_ok}/{n_sup} = {acc:.1%}")

    lines.append("=" * 60 + "\n")
    print("\n".join(lines), flush=True)
def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None": def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None":
"""Build per-frame sampling weights that oversample VQA-annotated frames. """Build per-frame sampling weights that oversample VQA-annotated frames.
@@ -542,6 +627,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0
# Optional periodic head-prediction dump for the LM head:
# ``LEROBOT_DEBUG_PREDS_EVERY=1000`` prints 5 samples + per-token
# (label, argmax, ✓/✗) every 1000 steps. Cheap diagnostic to see
# whether the text head is actually learning what we expect, vs
# collapsing to a fixed token. Refilling the recipe-sample dump
# budget at the same cadence also redumps the raw input shapes.
_debug_preds_every = int(os.environ.get("LEROBOT_DEBUG_PREDS_EVERY", "0"))
if (
_debug_preds_every > 0
and step % _debug_preds_every == 0
and is_main_process
):
try:
from lerobot.policies.pi052 import text_processor_pi052 as _tp # noqa: PLC0415
_tp._DUMPED_SO_FAR = 0
_tp._DUMP_BUDGET = max(_tp._DUMP_BUDGET, 5)
except Exception: # noqa: BLE001
pass
_print_debug_text_predictions(policy, batch, step, n_samples=5)
if is_log_step: if is_log_step:
logging.info(train_tracker) logging.info(train_tracker)
if wandb_logger: if wandb_logger: