From f7747d02a9f4b2cfd8cf1cead3a8a5482f52e528 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 21 May 2026 12:23:05 +0200 Subject: [PATCH] feat(train): periodic LM-head prediction dump for live debugging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds an opt-in diagnostic that, every N training steps, dumps 5 batch samples plus the LM head's argmax prediction at every supervised position alongside the label and a ✓/✗ marker — the cheapest signal for "is text training actually learning what we expect, or collapsing to a fixed token". Refills the recipe-sample dump budget on the same cadence so the raw input shapes are also re-dumped. Opt in via env var: LEROBOT_DEBUG_PREDS_EVERY=1000 lerobot-train ... PI052 implements ``debug_text_predictions`` (mirrors the text-loss forward but returns argmax instead of CE); other policies are silently skipped. The dump runs in eval() mode under no_grad, slicing the current batch to N samples — no extra data fetch, no train-state mutation. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/lerobot/policies/pi052/modeling_pi052.py | 84 +++++++++++++++ src/lerobot/scripts/lerobot_train.py | 106 +++++++++++++++++++ 2 files changed, 190 insertions(+) diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index e587eca4e..555ae7738 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -782,6 +782,90 @@ class PI052Policy(PI05Policy): return text_loss, fast_loss + # ------------------------------------------------------------------ + # Diagnostic: forward + argmax for supervised text positions + # ------------------------------------------------------------------ + + @torch.no_grad() + def debug_text_predictions( + self, batch: dict[str, Tensor], max_samples: int = 5 + ) -> dict[str, Tensor]: + """Run the text-loss forward but return argmax predictions instead of CE. + + Lets a periodic training-loop hook compare what the LM head emits + right now against what it *should* emit at every supervised + position — the cheapest "is text training actually working" + diagnostic. Returns CPU tensors keyed by ``input_ids``, + ``attention_mask``, ``labels``, ``predictions``; predictions are + aligned with input positions (``predictions[t]`` is the head's + argmax after seeing ``input_ids[:t+1]``, so it should match + ``input_ids[t+1]`` for next-token prediction). Returns ``{}`` + when the batch has no supervised text positions. + """ + from ..pi05.modeling_pi05 import make_att_2d_masks # noqa: PLC0415 + + text_labels = batch.get("text_labels") + if text_labels is None or not bool((text_labels != -100).any().item()): + return {} + + was_training = self.training + self.eval() + try: + n = min(max_samples, int(text_labels.shape[0])) + sub: dict[str, Any] = { + OBS_LANGUAGE_TOKENS: batch[OBS_LANGUAGE_TOKENS][:n], + OBS_LANGUAGE_ATTENTION_MASK: batch[OBS_LANGUAGE_ATTENTION_MASK][:n], + } + for k, v in batch.items(): + if isinstance(k, str) and k.startswith("observation.images.") and torch.is_tensor(v): + sub[k] = v[:n] + + sub_labels = text_labels[:n] + images, img_masks = self._preprocess_images(sub) + lang_tokens = sub[OBS_LANGUAGE_TOKENS] + lang_masks = sub[OBS_LANGUAGE_ATTENTION_MASK] + + prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix( + images, img_masks, lang_tokens, lang_masks + ) + lang_start = prefix_embs.shape[1] - sub_labels.shape[1] + if lang_start >= 0: + prefix_att = _mark_target_span_causal( + prefix_att, sub_labels, lang_start, prefix_embs.shape[1] + ) + + att_2d = make_att_2d_masks(prefix_pad, prefix_att) + position_ids = torch.cumsum(prefix_pad, dim=1) - 1 + att_2d_4d = self.model._prepare_attention_masks_4d(att_2d) + backbone = self.model.paligemma_with_expert + backbone_dtype = ( + backbone.paligemma.model.language_model.layers[0] + .self_attn.q_proj.weight.dtype + ) + if att_2d_4d.dtype != backbone_dtype: + att_2d_4d = att_2d_4d.to(dtype=backbone_dtype) + + (vlm_out, _), _ = backbone.forward( + attention_mask=att_2d_4d, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=[prefix_embs, None], + use_cache=False, + ) + text_hidden = vlm_out[:, -sub_labels.shape[1]:, :] + lm_head = backbone.paligemma.lm_head + text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) + preds = text_logits.argmax(dim=-1) + return { + "input_ids": lang_tokens.detach().cpu(), + "attention_mask": lang_masks.detach().cpu(), + "labels": sub_labels.detach().cpu(), + "predictions": preds.detach().cpu(), + } + finally: + if was_training: + self.train() + # ------------------------------------------------------------------ # select_message — AR text generation at inference # ------------------------------------------------------------------ diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 2c9b3ad56..4e13abff1 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -20,6 +20,7 @@ Requires: pip install 'lerobot[training]' (includes dataset + accelerate + wand import dataclasses import logging +import os import time from contextlib import nullcontext from pprint import pformat @@ -156,6 +157,90 @@ def update_policy( return train_metrics, output_dict +def _print_debug_text_predictions( + policy: Any, batch: dict[str, Any], step: int, n_samples: int = 5 +) -> None: + """Forward the current batch and print head-argmax vs label per supervised position. + + Opt-in via ``LEROBOT_DEBUG_PREDS_EVERY=``. Only the + policy types that expose ``debug_text_predictions`` participate + (currently PI052); others are silently skipped. Pretty-prints up to + ``n_samples`` samples from the current batch, showing the prompt, + every supervised position's (label, prediction, ✓/✗), and a + per-sample token-accuracy summary — the cheapest "is text training + actually learning anything" signal. + """ + if not hasattr(policy, "debug_text_predictions"): + return + try: + debug = policy.debug_text_predictions(batch, max_samples=n_samples) + except Exception as exc: # noqa: BLE001 + logging.warning("debug_text_predictions failed: %s", exc) + return + if not debug: + return + + # Build a tokenizer for decoding — match training side exactly. + try: + from transformers import AutoTokenizer # noqa: PLC0415 + + from lerobot.policies.pi052.text_processor_pi052 import ( # noqa: PLC0415 + register_paligemma_loc_tokens, + ) + + tok_name = ( + getattr(policy.config, "tokenizer_name", None) or "google/paligemma-3b-pt-224" + ) + tokenizer = register_paligemma_loc_tokens(AutoTokenizer.from_pretrained(tok_name)) + except Exception as exc: # noqa: BLE001 + logging.warning("debug preds: tokenizer load failed: %s", exc) + return + + ids = debug["input_ids"] + labels = debug["labels"] + preds = debug["predictions"] + attn = debug["attention_mask"] + + n = ids.shape[0] + print( + f"\n========== STEP {step} DEBUG PREDICTIONS ({n} samples) ==========", + flush=True, + ) + for s in range(n): + a = attn[s].tolist() + real = sum(a) + sid = ids[s].tolist() + sl = labels[s].tolist() + sp = preds[s].tolist() + prompt = tokenizer.decode(sid[:real], skip_special_tokens=False) + print(f"\n --- sample {s + 1}/{n} ---", flush=True) + print(f" prompt: {prompt!r}", flush=True) + n_sup = n_ok = 0 + rows: list[str] = [] + # CE shift: pred[t] predicts label[t+1]. Iterate supervised label + # positions (i = t+1) and align with prediction at t = i-1. + for i in range(1, real): + label = sl[i] + if label == -100: + continue + n_sup += 1 + pred = sp[i - 1] + ok = label == pred + n_ok += int(ok) + lbl_str = tokenizer.decode([label]) if 0 <= label < tokenizer.vocab_size + 2048 else "" + pred_str = tokenizer.decode([pred]) if 0 <= pred < tokenizer.vocab_size + 2048 else "" + mark = "✓" if ok else "✗" + rows.append( + f" pos {i - 1:3d} → {i:3d}: label {label:6d} {lbl_str!r:20s} | " + f"pred {pred:6d} {pred_str!r:20s} {mark}" + ) + for r in rows: + print(r, flush=True) + acc = n_ok / max(n_sup, 1) + print(f" token accuracy: {n_ok}/{n_sup} = {acc:.1%}", flush=True) + print("=" * 60 + "\n", flush=True) + + def _build_vqa_oversample_weights(dataset: Any, target_fraction: float) -> "torch.Tensor | None": """Build per-frame sampling weights that oversample VQA-annotated frames. @@ -542,6 +627,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_eval_step = cfg.eval_freq > 0 and step % cfg.eval_freq == 0 + # Optional periodic head-prediction dump for the LM head: + # ``LEROBOT_DEBUG_PREDS_EVERY=1000`` prints 5 samples + per-token + # (label, argmax, ✓/✗) every 1000 steps. Cheap diagnostic to see + # whether the text head is actually learning what we expect, vs + # collapsing to a fixed token. Refilling the recipe-sample dump + # budget at the same cadence also redumps the raw input shapes. + _debug_preds_every = int(os.environ.get("LEROBOT_DEBUG_PREDS_EVERY", "0")) + if ( + _debug_preds_every > 0 + and step % _debug_preds_every == 0 + and is_main_process + ): + try: + from lerobot.policies.pi052 import text_processor_pi052 as _tp # noqa: PLC0415 + + _tp._DUMPED_SO_FAR = 0 + _tp._DUMP_BUDGET = max(_tp._DUMP_BUDGET, 5) + except Exception: # noqa: BLE001 + pass + _print_debug_text_predictions(policy, batch, step, n_samples=5) + if is_log_step: logging.info(train_tracker) if wandb_logger: