mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 06:07:40 +00:00
feat(debug): LEROBOT_DUMP_RECIPE_SAMPLES=N dumps the first N rendered samples
Adds a one-shot debug dumper to both chat processors. When the env var ``LEROBOT_DUMP_RECIPE_SAMPLES`` is set to a positive integer N, the next N samples processed (rank-0 only) get pretty-printed: * the recipe-rendered messages (role / stream / target / content), * the full tokenized prompt (decoded back), * inline ``[TGT]...[/TGT]`` markers over the spans the LM head is supervised on, * token count + target-token count, * ``predict_actions`` flag. Usage: LEROBOT_DUMP_RECIPE_SAMPLES=5 sbatch train_smolvla2.slurm After N dumps the helper becomes a no-op; training continues unaffected. Works for both smolvla2 (chat-template renderer) and pi052 (plain ``Role: content`` concat renderer); each processor has its own copy to avoid cross-package imports. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -37,6 +37,7 @@ Outputs:
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -50,6 +51,80 @@ from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TO
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Debug helper — see ``chat_processor_smolvla2._dump_recipe_sample`` for the
|
||||
# matching SmolVLA2 implementation. Behaviour: when
|
||||
# ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N samples processed (on
|
||||
# rank 0) are pretty-printed with ``[TGT]...[/TGT]`` markers over the spans
|
||||
# the LM head will be supervised on.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
|
||||
_DUMPED_SO_FAR = 0
|
||||
|
||||
|
||||
def _is_dump_rank() -> bool:
|
||||
rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
|
||||
try:
|
||||
return int(rank) == 0
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
|
||||
def _dump_recipe_sample(
|
||||
*,
|
||||
messages: list[dict[str, Any]],
|
||||
prompt_text: str,
|
||||
token_ids: list[int],
|
||||
labels: list[int],
|
||||
predict_actions: bool,
|
||||
tokenizer: Any,
|
||||
) -> None:
|
||||
"""Pretty-print one rendered sample. Stops once the global budget is hit."""
|
||||
global _DUMPED_SO_FAR
|
||||
if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
|
||||
return
|
||||
_DUMPED_SO_FAR += 1
|
||||
|
||||
parts: list[str] = []
|
||||
i = 0
|
||||
while i < len(labels):
|
||||
if labels[i] == -100:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] == -100:
|
||||
j += 1
|
||||
parts.append(tokenizer.decode(token_ids[i:j], skip_special_tokens=False))
|
||||
i = j
|
||||
else:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] != -100:
|
||||
j += 1
|
||||
tgt_text = tokenizer.decode(token_ids[i:j], skip_special_tokens=False)
|
||||
parts.append(f"[TGT]{tgt_text}[/TGT]")
|
||||
i = j
|
||||
annotated = "".join(parts)
|
||||
|
||||
n_tgt = sum(1 for l in labels if l != -100)
|
||||
print(
|
||||
"\n========== RECIPE SAMPLE DUMP "
|
||||
f"({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
|
||||
flush=True,
|
||||
)
|
||||
print(f" predict_actions: {predict_actions}", flush=True)
|
||||
print(f" rendered messages ({len(messages)}):", flush=True)
|
||||
for m in messages:
|
||||
stream = m.get("stream")
|
||||
target = m.get("target")
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
print(f" - role={role} stream={stream} target={target}", flush=True)
|
||||
print(f" content: {content!r}", flush=True)
|
||||
print(f" rendered prompt:\n {prompt_text!r}", flush=True)
|
||||
print(f" token count: {len(token_ids)} (target tokens: {n_tgt})", flush=True)
|
||||
print(f" decoded (with target markers):\n {annotated}", flush=True)
|
||||
print("==============================================\n", flush=True)
|
||||
|
||||
|
||||
def _strip_blocks(message: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Normalise a message's content to a plain string.
|
||||
|
||||
@@ -221,6 +296,16 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
dtype=torch.bool,
|
||||
)
|
||||
|
||||
if _DUMP_BUDGET > 0:
|
||||
_dump_recipe_sample(
|
||||
messages=messages,
|
||||
prompt_text=prompt,
|
||||
token_ids=input_ids.tolist(),
|
||||
labels=labels.tolist(),
|
||||
predict_actions=bool(predict_actions.item()),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0)
|
||||
obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0)
|
||||
|
||||
@@ -39,6 +39,7 @@ matching the chat-template-stripped text order).
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -52,6 +53,84 @@ from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TO
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Debug helper: dump the first N rendered samples to stdout so you can sanity-
|
||||
# check what the model actually sees before kicking off a long training run.
|
||||
#
|
||||
# LEROBOT_DUMP_RECIPE_SAMPLES=5 lerobot-train ...
|
||||
#
|
||||
# Prints the recipe-rendered messages, the chat-templated text (decoded back
|
||||
# from token ids), and inline ``[TGT]...[/TGT]`` markers showing which spans
|
||||
# are supervised by text-CE. Stops after N total dumps to keep training logs
|
||||
# tractable. Rank-0 only when accelerate sets ``RANK``.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
|
||||
_DUMPED_SO_FAR = 0
|
||||
|
||||
|
||||
def _is_dump_rank() -> bool:
|
||||
rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
|
||||
try:
|
||||
return int(rank) == 0
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
|
||||
def _dump_recipe_sample(
|
||||
*,
|
||||
messages: list[dict[str, Any]],
|
||||
token_ids: list[int],
|
||||
labels: list[int],
|
||||
predict_actions: bool,
|
||||
tokenizer: Any,
|
||||
) -> None:
|
||||
"""Pretty-print one rendered sample. Stops once the global budget is hit."""
|
||||
global _DUMPED_SO_FAR
|
||||
if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
|
||||
return
|
||||
_DUMPED_SO_FAR += 1
|
||||
|
||||
decoded = tokenizer.decode(token_ids, skip_special_tokens=False)
|
||||
parts: list[str] = []
|
||||
i = 0
|
||||
while i < len(labels):
|
||||
if labels[i] == -100:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] == -100:
|
||||
j += 1
|
||||
parts.append(tokenizer.decode(token_ids[i:j], skip_special_tokens=False))
|
||||
i = j
|
||||
else:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] != -100:
|
||||
j += 1
|
||||
tgt_text = tokenizer.decode(token_ids[i:j], skip_special_tokens=False)
|
||||
parts.append(f"[TGT]{tgt_text}[/TGT]")
|
||||
i = j
|
||||
annotated = "".join(parts)
|
||||
|
||||
n_tgt = sum(1 for l in labels if l != -100)
|
||||
print(
|
||||
"\n========== RECIPE SAMPLE DUMP "
|
||||
f"({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
|
||||
flush=True,
|
||||
)
|
||||
print(f" predict_actions: {predict_actions}", flush=True)
|
||||
print(f" rendered messages ({len(messages)}):", flush=True)
|
||||
for m in messages:
|
||||
stream = m.get("stream")
|
||||
target = m.get("target")
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
print(f" - role={role} stream={stream} target={target}", flush=True)
|
||||
print(f" content: {content!r}", flush=True)
|
||||
print(f" token count: {len(token_ids)} (target tokens: {n_tgt})", flush=True)
|
||||
print(f" decoded (raw):\n {decoded}", flush=True)
|
||||
print(f" decoded (with target markers):\n {annotated}", flush=True)
|
||||
print("==============================================\n", flush=True)
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="smolvla2_chat_tokenizer")
|
||||
class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
@@ -156,6 +235,20 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
||||
)
|
||||
]
|
||||
|
||||
# Optional first-N-samples debug dump for sanity-checking what the
|
||||
# model actually sees. No-op unless ``LEROBOT_DUMP_RECIPE_SAMPLES``
|
||||
# is set; stops globally after the budget is exhausted.
|
||||
if _DUMP_BUDGET > 0:
|
||||
msgs_iter = messages if _is_batched_messages(messages) else [messages]
|
||||
for msg, (ids, labels, predict_action) in zip(msgs_iter, encoded, strict=False):
|
||||
_dump_recipe_sample(
|
||||
messages=msg,
|
||||
token_ids=ids,
|
||||
labels=labels,
|
||||
predict_actions=predict_action,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
|
||||
target_length = self.max_length if self.padding == "max_length" else max(
|
||||
len(ids) for ids, _, _ in encoded
|
||||
|
||||
Reference in New Issue
Block a user