mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
feat(train): debug dump runs inference too, with parity check
Extends the periodic LM-head dump (LEROBOT_DEBUG_PREDS_EVERY) to ALSO
run select_message autoregressively on the same prompt prefix and show:
prompt : '<bos>User: ... Assistant: '
target (ground truth) : ' close the gripper ...'
training argmax (teacher-fed) : ' close the gri lift ...' acc=12/15=80.0%
inference (autoregressive) : ' close the gripper around ...'
first-token parity : train=3387 (' close') vs infer=3387 (' close') ✓ MATCH
The first-token parity check is decisive: training-side argmax at the
prompt-end position and inference's first generated token both compute
``argmax(lm_head(h_last_prompt))`` on identical context, so they MUST
match. Any divergence signals a training↔inference bug (mask, dtype,
KI routing, embedding scale, etc.). Subsequent tokens can diverge
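because training uses teacher forcing while inference free-runs.
Caveat: the printer recovers inference's first token by re-tokenizing the
decoded string, so a DIVERGED verdict may also reflect a tokenizer
decode/encode round-trip difference rather than a model-side bug.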
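debug_text_predictions now also returns an ``inference`` list with one
entry per sample (in batch order), each carrying ``first_sup_pos`` and
``decoded``; a sample with no supervised tokens instead gets
``{"first_token": None, "decoded": ""}``.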
Limited to 24 new tokens per sample to keep the dump fast.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -856,11 +856,54 @@ class PI052Policy(PI05Policy):
|
|||||||
lm_head = backbone.paligemma.lm_head
|
lm_head = backbone.paligemma.lm_head
|
||||||
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
text_logits = lm_head(text_hidden.to(lm_head.weight.dtype))
|
||||||
preds = text_logits.argmax(dim=-1)
|
preds = text_logits.argmax(dim=-1)
|
||||||
|
|
||||||
|
# Train/inference parity check — run select_message on the
|
||||||
|
# *same* prompt prefix (the language up to but not including
|
||||||
|
# the supervised span) and capture the auto-regressive
|
||||||
|
# generation. The first generated token MUST match the
|
||||||
|
# training-side argmax at the prompt-end position (both are
|
||||||
|
# ``argmax lm_head(h_last_prompt)`` over identical context);
|
||||||
|
# any divergence is a parity bug (mask, dtype, KI routing
|
||||||
|
# difference). Later tokens can diverge because training
|
||||||
|
# uses teacher forcing while inference free-runs.
|
||||||
|
inference_outputs: list[dict[str, Any]] = []
|
||||||
|
for s in range(n):
|
||||||
|
row_labels = sub_labels[s]
|
||||||
|
sup_pos = (row_labels != -100).nonzero(as_tuple=True)[0]
|
||||||
|
if sup_pos.numel() == 0:
|
||||||
|
inference_outputs.append({"first_token": None, "decoded": ""})
|
||||||
|
continue
|
||||||
|
first_sup = int(sup_pos[0].item())
|
||||||
|
# Build a single-sample batch with attention zeroed past
|
||||||
|
# the supervised span — that gives ``embed_prefix`` only
|
||||||
|
# the user-prompt portion to attend over.
|
||||||
|
prompt_mask = sub[OBS_LANGUAGE_ATTENTION_MASK][s : s + 1].clone()
|
||||||
|
prompt_mask[:, first_sup:] = 0
|
||||||
|
inf_batch: dict[str, Any] = {
|
||||||
|
OBS_LANGUAGE_TOKENS: sub[OBS_LANGUAGE_TOKENS][s : s + 1],
|
||||||
|
OBS_LANGUAGE_ATTENTION_MASK: prompt_mask,
|
||||||
|
}
|
||||||
|
for k, v in sub.items():
|
||||||
|
if isinstance(k, str) and k.startswith("observation.images."):
|
||||||
|
inf_batch[k] = v[s : s + 1]
|
||||||
|
if "observation.state" in batch and torch.is_tensor(batch["observation.state"]):
|
||||||
|
inf_batch["observation.state"] = batch["observation.state"][s : s + 1]
|
||||||
|
try:
|
||||||
|
# Tight budget — we just want to see the model's
|
||||||
|
# opening continuation, not the full sequence.
|
||||||
|
decoded = self.select_message(
|
||||||
|
inf_batch, max_new_tokens=24, temperature=0.0, top_p=1.0
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
decoded = f"<inference failed: {type(exc).__name__}: {exc}>"
|
||||||
|
inference_outputs.append({"first_sup_pos": first_sup, "decoded": decoded})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"input_ids": lang_tokens.detach().cpu(),
|
"input_ids": lang_tokens.detach().cpu(),
|
||||||
"attention_mask": lang_masks.detach().cpu(),
|
"attention_mask": lang_masks.detach().cpu(),
|
||||||
"labels": sub_labels.detach().cpu(),
|
"labels": sub_labels.detach().cpu(),
|
||||||
"predictions": preds.detach().cpu(),
|
"predictions": preds.detach().cpu(),
|
||||||
|
"inference": inference_outputs,
|
||||||
}
|
}
|
||||||
finally:
|
finally:
|
||||||
if was_training:
|
if was_training:
|
||||||
|
|||||||
@@ -200,6 +200,7 @@ def _print_debug_text_predictions(
|
|||||||
labels = debug["labels"]
|
labels = debug["labels"]
|
||||||
preds = debug["predictions"]
|
preds = debug["predictions"]
|
||||||
attn = debug["attention_mask"]
|
attn = debug["attention_mask"]
|
||||||
|
inference = debug.get("inference") or []
|
||||||
|
|
||||||
n = ids.shape[0]
|
n = ids.shape[0]
|
||||||
print(
|
print(
|
||||||
@@ -215,29 +216,60 @@ def _print_debug_text_predictions(
|
|||||||
prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
|
prompt = tokenizer.decode(sid[:real], skip_special_tokens=False)
|
||||||
print(f"\n --- sample {s + 1}/{n} ---", flush=True)
|
print(f"\n --- sample {s + 1}/{n} ---", flush=True)
|
||||||
print(f" prompt: {prompt!r}", flush=True)
|
print(f" prompt: {prompt!r}", flush=True)
|
||||||
|
|
||||||
|
# Ground-truth target (the contiguous supervised label span).
|
||||||
|
sup_ids = [int(sid[i]) for i in range(real) if sl[i] != -100]
|
||||||
|
if sup_ids:
|
||||||
|
print(
|
||||||
|
f" target (ground truth) : {tokenizer.decode(sup_ids, skip_special_tokens=False)!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training-side teacher-forced argmax on the same prompt+target.
|
||||||
n_sup = n_ok = 0
|
n_sup = n_ok = 0
|
||||||
rows: list[str] = []
|
first_sup_pred: int | None = None
|
||||||
# CE shift: pred[t] predicts label[t+1]. Iterate supervised label
|
teacher_chars: list[int] = []
|
||||||
# positions (i = t+1) and align with prediction at t = i-1.
|
|
||||||
for i in range(1, real):
|
for i in range(1, real):
|
||||||
label = sl[i]
|
label = sl[i]
|
||||||
if label == -100:
|
if label == -100:
|
||||||
continue
|
continue
|
||||||
n_sup += 1
|
n_sup += 1
|
||||||
pred = sp[i - 1]
|
pred = int(sp[i - 1])
|
||||||
ok = label == pred
|
if first_sup_pred is None:
|
||||||
n_ok += int(ok)
|
first_sup_pred = pred
|
||||||
lbl_str = tokenizer.decode([label]) if 0 <= label < tokenizer.vocab_size + 2048 else "<oob>"
|
teacher_chars.append(pred)
|
||||||
pred_str = tokenizer.decode([pred]) if 0 <= pred < tokenizer.vocab_size + 2048 else "<oob>"
|
if label == pred:
|
||||||
mark = "✓" if ok else "✗"
|
n_ok += 1
|
||||||
rows.append(
|
teacher_text = (
|
||||||
f" pos {i - 1:3d} → {i:3d}: label {label:6d} {lbl_str!r:20s} | "
|
tokenizer.decode(teacher_chars, skip_special_tokens=False) if teacher_chars else ""
|
||||||
f"pred {pred:6d} {pred_str!r:20s} {mark}"
|
)
|
||||||
)
|
|
||||||
for r in rows:
|
|
||||||
print(r, flush=True)
|
|
||||||
acc = n_ok / max(n_sup, 1)
|
acc = n_ok / max(n_sup, 1)
|
||||||
print(f" token accuracy: {n_ok}/{n_sup} = {acc:.1%}", flush=True)
|
print(
|
||||||
|
f" training argmax (teacher-fed) : {teacher_text!r} acc={n_ok}/{n_sup}={acc:.1%}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Inference-side autoregressive output from the same prompt prefix.
|
||||||
|
inf_entry = inference[s] if s < len(inference) else None
|
||||||
|
if inf_entry:
|
||||||
|
inf_decoded = inf_entry.get("decoded", "")
|
||||||
|
print(f" inference (autoregressive) : {inf_decoded!r}", flush=True)
|
||||||
|
# First-token parity: training-side argmax at the prompt-end
|
||||||
|
# position MUST equal inference's first generated token —
|
||||||
|
# both compute argmax(lm_head(h_last_prompt)) on identical
|
||||||
|
# context. Any divergence signals a training↔inference bug.
|
||||||
|
if first_sup_pred is not None and inf_decoded and not inf_decoded.startswith("<inference"):
|
||||||
|
inf_ids = tokenizer(inf_decoded, add_special_tokens=False)["input_ids"]
|
||||||
|
if inf_ids:
|
||||||
|
inf_first = int(inf_ids[0])
|
||||||
|
match = inf_first == first_sup_pred
|
||||||
|
print(
|
||||||
|
f" first-token parity : "
|
||||||
|
f"train={first_sup_pred} ({tokenizer.decode([first_sup_pred])!r}) "
|
||||||
|
f"vs infer={inf_first} ({tokenizer.decode([inf_first])!r}) "
|
||||||
|
f"{'✓ MATCH' if match else '✗ DIVERGED — training/inference mismatch'}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
print("=" * 60 + "\n", flush=True)
|
print("=" * 60 + "\n", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user