diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 555ae7738..c0b30f707 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -856,11 +856,54 @@ class PI052Policy(PI05Policy): lm_head = backbone.paligemma.lm_head text_logits = lm_head(text_hidden.to(lm_head.weight.dtype)) preds = text_logits.argmax(dim=-1) + + # Train/inference parity check — run select_message on the + # *same* prompt prefix (the language up to but not including + # the supervised span) and capture the auto-regressive + # generation. The first generated token MUST match the + # training-side argmax at the prompt-end position (both are + # ``argmax lm_head(h_last_prompt)`` over identical context); + # any divergence is a parity bug (mask, dtype, KI routing + # difference). Later tokens can diverge because training + # uses teacher forcing while inference free-runs. + inference_outputs: list[dict[str, Any]] = [] + for s in range(n): + row_labels = sub_labels[s] + sup_pos = (row_labels != -100).nonzero(as_tuple=True)[0] + if sup_pos.numel() == 0: + inference_outputs.append({"first_token": None, "decoded": ""}) + continue + first_sup = int(sup_pos[0].item()) + # Build a single-sample batch with attention zeroed past + # the supervised span — that gives ``embed_prefix`` only + # the user-prompt portion to attend over. + prompt_mask = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1].clone() + prompt_mask[:, first_sup:] = 0 + inf_batch: dict[str, Any] = { + OBS_LANGUAGE_TOKENS: sub[OBS_LANGUAGE_TOKENS][s : s + 1], + OBS_LANGUAGE_ATTENTION_MASK: prompt_mask, + } + for k, v in sub.items(): + if isinstance(k, str) and k.startswith("observation.images."): + inf_batch[k] = v[s : s + 1] + if "observation.state" in batch and torch.is_tensor(batch["observation.state"]): + inf_batch["observation.state"] = batch["observation.state"][s : s + 1] + try: + # Tight budget — we just want to see the model's + # opening continuation, not the full sequence. + decoded = self.select_message( + inf_batch, max_new_tokens=24, temperature=0.0, top_p=1.0 + ) + except Exception as exc: # noqa: BLE001 + decoded = f"" + inference_outputs.append({"first_sup_pos": first_sup, "decoded": decoded}) + return { "input_ids": lang_tokens.detach().cpu(), "attention_mask": lang_masks.detach().cpu(), "labels": sub_labels.detach().cpu(), "predictions": preds.detach().cpu(), + "inference": inference_outputs, } finally: if was_training: diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 4e13abff1..3a64912ed 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -200,6 +200,7 @@ def _print_debug_text_predictions( labels = debug["labels"] preds = debug["predictions"] attn = debug["attention_mask"] + inference = debug.get("inference") or [] n = ids.shape[0] print( @@ -215,29 +216,60 @@ def _print_debug_text_predictions( prompt = tokenizer.decode(sid[:real], skip_special_tokens=False) print(f"\n --- sample {s + 1}/{n} ---", flush=True) print(f" prompt: {prompt!r}", flush=True) + + # Ground-truth target (the contiguous supervised label span). + sup_ids = [int(sid[i]) for i in range(real) if sl[i] != -100] + if sup_ids: + print( + f" target (ground truth) : {tokenizer.decode(sup_ids, skip_special_tokens=False)!r}", + flush=True, + ) + + # Training-side teacher-forced argmax on the same prompt+target. n_sup = n_ok = 0 - rows: list[str] = [] - # CE shift: pred[t] predicts label[t+1]. Iterate supervised label - # positions (i = t+1) and align with prediction at t = i-1. + first_sup_pred: int | None = None + teacher_chars: list[int] = [] for i in range(1, real): label = sl[i] if label == -100: continue n_sup += 1 - pred = sp[i - 1] - ok = label == pred - n_ok += int(ok) - lbl_str = tokenizer.decode([label]) if 0 <= label < tokenizer.vocab_size + 2048 else "" - pred_str = tokenizer.decode([pred]) if 0 <= pred < tokenizer.vocab_size + 2048 else "" - mark = "✓" if ok else "✗" - rows.append( - f" pos {i - 1:3d} → {i:3d}: label {label:6d} {lbl_str!r:20s} | " - f"pred {pred:6d} {pred_str!r:20s} {mark}" - ) - for r in rows: - print(r, flush=True) + pred = int(sp[i - 1]) + if first_sup_pred is None: + first_sup_pred = pred + teacher_chars.append(pred) + if label == pred: + n_ok += 1 + teacher_text = ( + tokenizer.decode(teacher_chars, skip_special_tokens=False) if teacher_chars else "" + ) acc = n_ok / max(n_sup, 1) - print(f" token accuracy: {n_ok}/{n_sup} = {acc:.1%}", flush=True) + print( + f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}", + flush=True, + ) + + # Inference-side autoregressive output from the same prompt prefix. + inf_entry = inference[s] if s < len(inference) else None + if inf_entry: + inf_decoded = inf_entry.get("decoded", "") + print(f" inference (autoregressive) : {inf_decoded!r}", flush=True) + # First-token parity: training-side argmax at the prompt-end + # position MUST equal inference's first generated token — + # both compute argmax(lm_head(h_last_prompt)) on identical + # context. Any divergence signals a training↔inference bug. + if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("