# ---------------------------------------------------------------------------
# Debug helper — see ``chat_processor_smolvla2._dump_recipe_sample`` for the
# matching SmolVLA2 implementation. Behaviour: when
# ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N samples processed (on
# rank 0) are pretty-printed with ``[TGT]...[/TGT]`` markers over the spans
# the LM head will be supervised on.
# ---------------------------------------------------------------------------

# Label value that marks a position as *not* supervised; every other label
# value is rendered inside ``[TGT]...[/TGT]``.
_IGNORE_INDEX = -100


def _read_dump_budget() -> int:
    """Parse ``LEROBOT_DUMP_RECIPE_SAMPLES`` into a sample budget.

    Returns:
        The integer value of the variable, or ``0`` when it is unset, empty,
        or not an integer. A malformed value is reported with a warning
        instead of raising, so a typo in a debug flag cannot break the import
        of this module.
    """
    raw = os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "").strip()
    if not raw:
        return 0
    try:
        return int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring LEROBOT_DUMP_RECIPE_SAMPLES=%r: expected an integer.", raw
        )
        return 0


_DUMP_BUDGET = _read_dump_budget()
_DUMPED_SO_FAR = 0


def _is_dump_rank() -> bool:
    """Return whether this process is allowed to print sample dumps.

    The rank is read from ``RANK``, then ``LOCAL_RANK``, then defaults to
    ``"0"`` when neither is set (or both are empty). Only rank 0 dumps.
    """
    rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
    try:
        return int(rank) == 0
    except ValueError:
        # Best effort: an unparsable rank must not silence the debug dump.
        return True


def _annotate_targets(token_ids: list[int], labels: list[int], tokenizer: Any) -> str:
    """Decode ``token_ids`` run by run, wrapping supervised runs in markers.

    Consecutive positions are grouped by whether their label equals
    ``_IGNORE_INDEX``. Each group is decoded with
    ``tokenizer.decode(..., skip_special_tokens=False)``; groups whose labels
    differ from ``_IGNORE_INDEX`` are wrapped in ``[TGT]...[/TGT]``.

    Args:
        token_ids: Token ids of the sample.
        labels: Per-position labels; iteration is bounded by ``len(labels)``.
        tokenizer: Any object exposing ``decode(ids, skip_special_tokens=...)``.

    Returns:
        The concatenated decoded text with target markers (``""`` when
        ``labels`` is empty).
    """
    parts: list[str] = []
    total = len(labels)
    start = 0
    while start < total:
        supervised = labels[start] != _IGNORE_INDEX
        end = start + 1
        while end < total and (labels[end] != _IGNORE_INDEX) == supervised:
            end += 1
        text = tokenizer.decode(token_ids[start:end], skip_special_tokens=False)
        parts.append(f"[TGT]{text}[/TGT]" if supervised else text)
        start = end
    return "".join(parts)


def _dump_recipe_sample(
    *,
    messages: list[dict[str, Any]],
    prompt_text: str,
    token_ids: list[int],
    labels: list[int],
    predict_actions: bool,
    tokenizer: Any,
) -> None:
    """Pretty-print one rendered sample. Stops once the global budget is hit.

    Does nothing when ``_DUMPED_SO_FAR`` has reached ``_DUMP_BUDGET`` or when
    ``_is_dump_rank()`` is false; otherwise increments ``_DUMPED_SO_FAR`` and
    prints the sample to stdout.

    Args:
        messages: Rendered messages; ``role``, ``stream``, ``target`` and
            ``content`` are read with ``dict.get`` and printed.
        prompt_text: Rendered prompt, printed with ``repr``.
        token_ids: Token ids of the sample.
        labels: Per-position labels; values other than ``-100`` are counted
            and marked as targets.
        predict_actions: Flag printed verbatim.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
    """
    global _DUMPED_SO_FAR
    if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
        return
    _DUMPED_SO_FAR += 1

    annotated = _annotate_targets(token_ids, labels, tokenizer)
    n_tgt = sum(1 for label in labels if label != _IGNORE_INDEX)

    lines = [
        f"\n========== RECIPE SAMPLE DUMP ({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
        f" predict_actions: {predict_actions}",
        f" rendered messages ({len(messages)}):",
    ]
    for message in messages:
        lines.append(
            f" - role={message.get('role')} stream={message.get('stream')} "
            f"target={message.get('target')}"
        )
        lines.append(f" content: {message.get('content')!r}")
    lines.append(f" rendered prompt:\n {prompt_text!r}")
    lines.append(f" token count: {len(token_ids)} (target tokens: {n_tgt})")
    lines.append(f" decoded (with target markers):\n {annotated}")
    lines.append("==============================================\n")

    # One write for the whole dump (same bytes as printing line by line),
    # instead of a separate flushed print per line.
    print("\n".join(lines), flush=True)
@@ -221,6 +296,16 @@ class PI052TextTokenizerStep(ProcessorStep): dtype=torch.bool, ) + if _DUMP_BUDGET > 0: + _dump_recipe_sample( + messages=messages, + prompt_text=prompt, + token_ids=input_ids.tolist(), + labels=labels.tolist(), + predict_actions=bool(predict_actions.item()), + tokenizer=tokenizer, + ) + obs = dict(transition.get(TransitionKey.OBSERVATION) or {}) obs[OBS_LANGUAGE_TOKENS] = input_ids.unsqueeze(0) obs[OBS_LANGUAGE_ATTENTION_MASK] = attention_mask.unsqueeze(0) diff --git a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py index 1cf88b0fd..23a5e5730 100644 --- a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py +++ b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py @@ -39,6 +39,7 @@ matching the chat-template-stripped text order). from __future__ import annotations import logging +import os from dataclasses import dataclass from typing import Any @@ -52,6 +53,84 @@ from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TO logger = logging.getLogger(__name__) +# --------------------------------------------------------------------------- +# Debug helper: dump the first N rendered samples to stdout so you can sanity- +# check what the model actually sees before kicking off a long training run. +# +# LEROBOT_DUMP_RECIPE_SAMPLES=5 lerobot-train ... +# +# Prints the recipe-rendered messages, the chat-templated text (decoded back +# from token ids), and inline ``[TGT]...[/TGT]`` markers showing which spans +# are supervised by text-CE. Stops after N total dumps to keep training logs +# tractable. Rank-0 only when accelerate sets ``RANK``. 
# ---------------------------------------------------------------------------

# Label value that marks a position as *not* supervised; every other label
# value is rendered inside ``[TGT]...[/TGT]``.
_IGNORE_INDEX = -100


def _read_dump_budget() -> int:
    """Parse ``LEROBOT_DUMP_RECIPE_SAMPLES`` into a sample budget.

    Returns:
        The integer value of the variable, or ``0`` when it is unset, empty,
        or not an integer. A malformed value is reported with a warning
        instead of raising, so a typo in a debug flag cannot break the import
        of this module.
    """
    raw = os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "").strip()
    if not raw:
        return 0
    try:
        return int(raw)
    except ValueError:
        logging.getLogger(__name__).warning(
            "Ignoring LEROBOT_DUMP_RECIPE_SAMPLES=%r: expected an integer.", raw
        )
        return 0


_DUMP_BUDGET = _read_dump_budget()
_DUMPED_SO_FAR = 0


def _is_dump_rank() -> bool:
    """Return whether this process is allowed to print sample dumps.

    The rank is read from ``RANK``, then ``LOCAL_RANK``, then defaults to
    ``"0"`` when neither is set (or both are empty). Only rank 0 dumps.
    """
    rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
    try:
        return int(rank) == 0
    except ValueError:
        # Best effort: an unparsable rank must not silence the debug dump.
        return True


def _annotate_targets(token_ids: list[int], labels: list[int], tokenizer: Any) -> str:
    """Decode ``token_ids`` run by run, wrapping supervised runs in markers.

    Consecutive positions are grouped by whether their label equals
    ``_IGNORE_INDEX``. Each group is decoded with
    ``tokenizer.decode(..., skip_special_tokens=False)``; groups whose labels
    differ from ``_IGNORE_INDEX`` are wrapped in ``[TGT]...[/TGT]``.

    Args:
        token_ids: Token ids of the sample.
        labels: Per-position labels; iteration is bounded by ``len(labels)``.
        tokenizer: Any object exposing ``decode(ids, skip_special_tokens=...)``.

    Returns:
        The concatenated decoded text with target markers (``""`` when
        ``labels`` is empty).
    """
    parts: list[str] = []
    total = len(labels)
    start = 0
    while start < total:
        supervised = labels[start] != _IGNORE_INDEX
        end = start + 1
        while end < total and (labels[end] != _IGNORE_INDEX) == supervised:
            end += 1
        text = tokenizer.decode(token_ids[start:end], skip_special_tokens=False)
        parts.append(f"[TGT]{text}[/TGT]" if supervised else text)
        start = end
    return "".join(parts)


def _dump_recipe_sample(
    *,
    messages: list[dict[str, Any]],
    token_ids: list[int],
    labels: list[int],
    predict_actions: bool,
    tokenizer: Any,
) -> None:
    """Pretty-print one rendered sample. Stops once the global budget is hit.

    Does nothing when ``_DUMPED_SO_FAR`` has reached ``_DUMP_BUDGET`` or when
    ``_is_dump_rank()`` is false; otherwise increments ``_DUMPED_SO_FAR`` and
    prints the sample to stdout, including both the raw decoded text and the
    decoded text with ``[TGT]...[/TGT]`` markers.

    Args:
        messages: Rendered messages; ``role``, ``stream``, ``target`` and
            ``content`` are read with ``dict.get`` and printed.
        token_ids: Token ids of the sample.
        labels: Per-position labels; values other than ``-100`` are counted
            and marked as targets.
        predict_actions: Flag printed verbatim.
        tokenizer: Object exposing ``decode(ids, skip_special_tokens=...)``.
    """
    global _DUMPED_SO_FAR
    if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
        return
    _DUMPED_SO_FAR += 1

    decoded = tokenizer.decode(token_ids, skip_special_tokens=False)
    annotated = _annotate_targets(token_ids, labels, tokenizer)
    n_tgt = sum(1 for label in labels if label != _IGNORE_INDEX)

    lines = [
        f"\n========== RECIPE SAMPLE DUMP ({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
        f" predict_actions: {predict_actions}",
        f" rendered messages ({len(messages)}):",
    ]
    for message in messages:
        lines.append(
            f" - role={message.get('role')} stream={message.get('stream')} "
            f"target={message.get('target')}"
        )
        lines.append(f" content: {message.get('content')!r}")
    lines.append(f" token count: {len(token_ids)} (target tokens: {n_tgt})")
    lines.append(f" decoded (raw):\n {decoded}")
    lines.append(f" decoded (with target markers):\n {annotated}")
    lines.append("==============================================\n")

    # One write for the whole dump (same bytes as printing line by line),
    # instead of a separate flushed print per line.
    print("\n".join(lines), flush=True)