feat(debug): LEROBOT_DUMP_RECIPE_SAMPLES=N dumps the first N rendered samples

Adds a one-shot debug dumper to both chat processors. When the env
var ``LEROBOT_DUMP_RECIPE_SAMPLES`` is set to a positive integer N,
the next N samples processed (rank-0 only) get pretty-printed:

* the recipe-rendered messages (role / stream / target / content),
* the full tokenized prompt (decoded back),
* inline ``[TGT]...[/TGT]`` markers over the spans the LM head is
  supervised on,
* token count + target-token count,
* ``predict_actions`` flag.

Usage:

  LEROBOT_DUMP_RECIPE_SAMPLES=5 sbatch train_smolvla2.slurm

After N dumps the helper becomes a no-op; training continues
unaffected. Works for both smolvla2 (chat-template renderer) and
pi052 (plain ``Role: content`` concat renderer); each processor has
its own copy to avoid cross-package imports.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-13 15:21:46 +02:00
parent 2c920ab178
commit 841d3c47e1
2 changed files with 178 additions and 0 deletions
@@ -37,6 +37,7 @@ Outputs:
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from typing import Any
@@ -50,6 +51,80 @@ from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TO
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Debug helper — see ``chat_processor_smolvla2._dump_recipe_sample`` for the
# matching SmolVLA2 implementation. Behaviour: when
# ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N samples processed (on
# rank 0) are pretty-printed with ``[TGT]...[/TGT]`` markers over the spans
# the LM head will be supervised on.
#
# The variable is read once, at import time. It is a debug-only knob, so an
# empty or non-integer value disables dumping (with a warning) instead of
# raising ``ValueError`` while this module is being imported.
# ---------------------------------------------------------------------------
_raw_dump_budget = os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "").strip()
try:
    _DUMP_BUDGET = int(_raw_dump_budget) if _raw_dump_budget else 0
except ValueError:
    logger.warning(
        "Ignoring LEROBOT_DUMP_RECIPE_SAMPLES=%r: expected an integer; sample dumping is disabled.",
        _raw_dump_budget,
    )
    _DUMP_BUDGET = 0
# Number of samples this process has dumped so far; advanced by
# ``_dump_recipe_sample``.
_DUMPED_SO_FAR = 0
def _is_dump_rank() -> bool:
    """Return True when this process should emit debug dumps.

    The rank is taken from the first non-empty of the ``RANK`` and
    ``LOCAL_RANK`` environment variables and defaults to ``"0"`` when neither
    is set. A value that cannot be parsed as an integer is treated as
    "dump here" rather than as an error.
    """
    raw_rank = "0"
    for var_name in ("RANK", "LOCAL_RANK"):
        value = os.environ.get(var_name)
        if value:
            raw_rank = value
            break

    try:
        rank = int(raw_rank)
    except ValueError:
        return True
    return rank == 0
def _dump_recipe_sample(
    *,
    messages: list[dict[str, Any]],
    prompt_text: str,
    token_ids: list[int],
    labels: list[int],
    predict_actions: bool,
    tokenizer: Any,
) -> None:
    """Pretty-print one rendered sample to stdout while the dump budget lasts.

    Does nothing once ``_DUMPED_SO_FAR`` has reached ``_DUMP_BUDGET`` or when
    ``_is_dump_rank()`` is false. The counter is a module global, so the
    budget is tracked per process.

    Args:
        messages: Recipe-rendered messages; ``role``, ``stream``, ``target``
            and ``content`` are read from each with ``dict.get``.
        prompt_text: The rendered prompt string, printed via ``repr``.
        token_ids: Token ids of the sample.
        labels: Per-position labels; positions equal to ``-100`` are treated
            as unsupervised, every other position as a target.
        predict_actions: Flag echoed in the dump.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
    """
    global _DUMPED_SO_FAR
    if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
        return
    _DUMPED_SO_FAR += 1

    from itertools import groupby

    # Walk maximal runs of target / non-target positions and decode each run
    # separately so target runs can be wrapped in [TGT]...[/TGT].
    annotated_parts: list[str] = []
    start = 0
    for is_target, run in groupby(labels, key=lambda label: label != -100):
        end = start + sum(1 for _ in run)
        text = tokenizer.decode(token_ids[start:end], skip_special_tokens=False)
        annotated_parts.append(f"[TGT]{text}[/TGT]" if is_target else text)
        start = end
    annotated = "".join(annotated_parts)
    n_tgt = sum(1 for label in labels if label != -100)

    lines = [
        f"\n========== RECIPE SAMPLE DUMP ({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
        f" predict_actions: {predict_actions}",
        f" rendered messages ({len(messages)}):",
    ]
    for message in messages:
        role = message.get("role")
        stream = message.get("stream")
        target = message.get("target")
        content = message.get("content")
        lines.append(f" - role={role} stream={stream} target={target}")
        lines.append(f" content: {content!r}")
    lines.append(f" rendered prompt:\n {prompt_text!r}")
    lines.append(f" token count: {len(token_ids)} (target tokens: {n_tgt})")
    if len(labels) != len(token_ids):
        # The annotated view only spans ``len(labels)`` positions; say so
        # instead of silently showing a truncated or misaligned sequence.
        lines.append(
            f" WARNING: {len(labels)} labels for {len(token_ids)} tokens; "
            "target markers may be misaligned"
        )
    lines.append(f" decoded (with target markers):\n {annotated}")
    lines.append("==============================================\n")

    # One write per dump: less likely to interleave with output from other
    # processes than a series of separately flushed prints.
    print("\n".join(lines), flush=True)
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
"""Normalise a message's content to a plain string.
@@ -221,6 +296,16 @@ class PI052TextTokenizerStep(ProcessorStep):
dtype=torch.bool,
)
if _DUMP_BUDGET > 0:
_dump_recipe_sample(
messages=messages,
prompt_text=prompt,
token_ids=input_ids.tolist(),
labels=labels.tolist(),
predict_actions=bool(predict_actions.item()),
tokenizer=tokenizer,
)
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0)
obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0)
@@ -39,6 +39,7 @@ matching the chat-template-stripped text order).
from __future__ import annotations
import logging
import os
from dataclasses import dataclass
from typing import Any
@@ -52,6 +53,84 @@ from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TO
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Debug helper: dump the first N rendered samples to stdout so you can sanity-
# check what the model actually sees before kicking off a long training run.
#
# LEROBOT_DUMP_RECIPE_SAMPLES=5 lerobot-train ...
#
# Prints the recipe-rendered messages, the chat-templated text (decoded back
# from token ids), and inline ``[TGT]...[/TGT]`` markers showing which spans
# are supervised by text-CE. Stops after N total dumps to keep training logs
# tractable. Rank-0 only: the rank is read from ``RANK``, then ``LOCAL_RANK``;
# a process with neither variable set is treated as rank 0.
#
# The variable is read once, at import time. It is a debug-only knob, so an
# empty or non-integer value disables dumping (with a warning) instead of
# raising ``ValueError`` while this module is being imported.
# ---------------------------------------------------------------------------
_raw_dump_budget = os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "").strip()
try:
    _DUMP_BUDGET = int(_raw_dump_budget) if _raw_dump_budget else 0
except ValueError:
    logger.warning(
        "Ignoring LEROBOT_DUMP_RECIPE_SAMPLES=%r: expected an integer; sample dumping is disabled.",
        _raw_dump_budget,
    )
    _DUMP_BUDGET = 0
# Number of samples this process has dumped so far; advanced by
# ``_dump_recipe_sample``.
_DUMPED_SO_FAR = 0
def _is_dump_rank() -> bool:
    """Tell whether the current process is the one allowed to dump samples.

    Looks at ``RANK`` first and ``LOCAL_RANK`` second, using the first value
    that is set and non-empty, and falling back to ``"0"``. Returns True for
    rank 0, and also when the value is not a valid integer.
    """
    candidates = (os.environ.get(name) for name in ("RANK", "LOCAL_RANK"))
    rank_text = next((value for value in candidates if value), "0")
    try:
        rank = int(rank_text)
    except ValueError:
        # An unparseable rank is treated as "dump here" rather than an error.
        return True
    return rank == 0
def _dump_recipe_sample(
    *,
    messages: list[dict[str, Any]],
    token_ids: list[int],
    labels: list[int],
    predict_actions: bool,
    tokenizer: Any,
) -> None:
    """Pretty-print one rendered sample to stdout while the dump budget lasts.

    Does nothing once ``_DUMPED_SO_FAR`` has reached ``_DUMP_BUDGET`` or when
    ``_is_dump_rank()`` is false. The counter is a module global, so the
    budget is tracked per process.

    Args:
        messages: Recipe-rendered messages; ``role``, ``stream``, ``target``
            and ``content`` are read from each with ``dict.get``.
        token_ids: Token ids of the sample; decoded once in full ("raw") and
            once run by run for the annotated view.
        labels: Per-position labels; positions equal to ``-100`` are treated
            as unsupervised, every other position as a target.
        predict_actions: Flag echoed in the dump.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
    """
    global _DUMPED_SO_FAR
    if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
        return
    _DUMPED_SO_FAR += 1

    from itertools import groupby

    decoded = tokenizer.decode(token_ids, skip_special_tokens=False)

    # Walk maximal runs of target / non-target positions and decode each run
    # separately so target runs can be wrapped in [TGT]...[/TGT].
    annotated_parts: list[str] = []
    start = 0
    for is_target, run in groupby(labels, key=lambda label: label != -100):
        end = start + sum(1 for _ in run)
        text = tokenizer.decode(token_ids[start:end], skip_special_tokens=False)
        annotated_parts.append(f"[TGT]{text}[/TGT]" if is_target else text)
        start = end
    annotated = "".join(annotated_parts)
    n_tgt = sum(1 for label in labels if label != -100)

    lines = [
        f"\n========== RECIPE SAMPLE DUMP ({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
        f" predict_actions: {predict_actions}",
        f" rendered messages ({len(messages)}):",
    ]
    for message in messages:
        role = message.get("role")
        stream = message.get("stream")
        target = message.get("target")
        content = message.get("content")
        lines.append(f" - role={role} stream={stream} target={target}")
        lines.append(f" content: {content!r}")
    lines.append(f" token count: {len(token_ids)} (target tokens: {n_tgt})")
    if len(labels) != len(token_ids):
        # The annotated view only spans ``len(labels)`` positions; say so
        # instead of silently showing a truncated or misaligned sequence.
        lines.append(
            f" WARNING: {len(labels)} labels for {len(token_ids)} tokens; "
            "target markers may be misaligned"
        )
    lines.append(f" decoded (raw):\n {decoded}")
    lines.append(f" decoded (with target markers):\n {annotated}")
    lines.append("==============================================\n")

    # One write per dump: less likely to interleave with output from other
    # processes than a series of separately flushed prints.
    print("\n".join(lines), flush=True)
@dataclass
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
class SmolVLA2ChatTokenizerStep(ProcessorStep):
@@ -156,6 +235,20 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
)
]
# Optional first-N-samples debug dump for sanity-checking what the
# model actually sees. No-op unless ``LEROBOT_DUMP_RECIPE_SAMPLES``
# is set; stops globally after the budget is exhausted.
if _DUMP_BUDGET > 0:
msgs_iter = messages if _is_batched_messages(messages) else [messages]
for msg, (ids, labels, predict_action) in zip(msgs_iter, encoded, strict=False):
_dump_recipe_sample(
messages=msg,
token_ids=ids,
labels=labels,
predict_actions=predict_action,
tokenizer=tokenizer,
)
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
target_length = self.max_length if self.padding == "max_length" else max(
len(ids) for ids, _, _ in encoded