pi052: drop debug scaffolding left over from training/inference bug hunts

Three diagnostic surfaces shipped in PR3 that don't belong in a clean
release:

* ``LEROBOT_DUMP_RECIPE_SAMPLES`` env-var dump (~70 LOC in
  text_processor_pi052.py): pretty-prints the next N rendered samples
  with ``[TGT]...[/TGT]`` markers over supervised spans. One-off
  training-inspection tool — no production user, never wired into a
  CLI flag, only useful while iterating on the recipe. Drop the module
  constants, the ``_is_dump_rank`` / ``_dump_recipe_sample`` helpers,
  the call site, and the now-unused ``import os``.

* ``_log_obs_tensors_once()`` in lerobot_pi052_runtime.py: the
  docstring literally says "Used to bisect train/inference mismatches"
  — a debugging artifact from when the LM head was collapsing on the
  live robot. Logged unconditionally at WARNING level from both the
  dataset-driven and robot-driven providers, with no ``--verbose``
  gate. Drop the function, both call sites, and the ``_logged`` /
  ``_obs_logged`` flag dicts that fed them. (``_resize_logged`` is
  kept — it gates the operationally useful camera-size sanity log.)

* Defensive ``unsqueeze(0)`` block in the dataset observation
  provider: papered over an upstream bug where some preprocessor step
  could produce an unbatched tensor. ``AddBatchDimensionProcessorStep``
  is reliable in the current pipeline — pi052 tests still pass with
  the block removed. If the bug ever resurfaces it should be fixed
  at the source, not silently re-batched here.

Net: -169 LOC. All 30 ``tests/policies/pi052/`` tests pass.

The ``<loc>`` token plumbing (``register_paligemma_loc_tokens``,
``_loc_token``, ``suppress_loc_tokens`` runtime gate) is left as-is —
it's the actual mechanism for VQA spatial answers, not scaffolding,
and the ``suppress_loc_tokens=True`` callers on subtask/memory/
interjection paths and ``=False`` on the VQA path are intentional
asymmetric behaviour, not a bug-routing knob.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-25 15:07:43 +02:00
parent 1ff10b935c
commit 83d0c390da
2 changed files with 1 additions and 170 deletions
@@ -37,7 +37,6 @@ from __future__ import annotations
import json
import logging
import os
from dataclasses import dataclass
from typing import Any
@@ -52,78 +51,6 @@ from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TO
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Debug helper — when ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N
# samples processed (on rank 0) are pretty-printed with ``[TGT]...[/TGT]``
# markers over the spans the LM head will be supervised on.
# ---------------------------------------------------------------------------
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
_DUMPED_SO_FAR = 0
def _is_dump_rank() -> bool:
rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
try:
return int(rank) == 0
except ValueError:
return True
def _dump_recipe_sample(
*,
messages: list[dict[str, Any]],
prompt_text: str,
token_ids: list[int],
labels: list[int],
predict_actions: bool,
tokenizer: Any,
) -> None:
"""Pretty-print one rendered sample. Stops once the global budget is hit."""
global _DUMPED_SO_FAR
if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
return
_DUMPED_SO_FAR += 1
parts: list[str] = []
i = 0
while i < len(labels):
if labels[i] == -100:
j = i
while j < len(labels) and labels[j] == -100:
j += 1
parts.append(tokenizer.decode(token_ids[i:j], skip_special_tokens=False))
i = j
else:
j = i
while j < len(labels) and labels[j] != -100:
j += 1
tgt_text = tokenizer.decode(token_ids[i:j], skip_special_tokens=False)
parts.append(f"[TGT]{tgt_text}[/TGT]")
i = j
annotated = "".join(parts)
n_tgt = sum(1 for l in labels if l != -100)
print(
"\n========== RECIPE SAMPLE DUMP "
f"({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
flush=True,
)
print(f" predict_actions: {predict_actions}", flush=True)
print(f" rendered messages ({len(messages)}):", flush=True)
for m in messages:
stream = m.get("stream")
target = m.get("target")
role = m.get("role")
content = m.get("content")
print(f" - role={role} stream={stream} target={target}", flush=True)
print(f" content: {content!r}", flush=True)
print(f" rendered prompt:\n {prompt_text!r}", flush=True)
print(f" token count: {len(token_ids)} (target tokens: {n_tgt})", flush=True)
print(f" decoded (with target markers):\n {annotated}", flush=True)
print("==============================================\n", flush=True)
def _content_to_text(content: Any) -> str:
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
if isinstance(content, str):
@@ -499,36 +426,6 @@ class PI052TextTokenizerStep(ProcessorStep):
)
]
if _DUMP_BUDGET > 0:
if _is_batched_messages(messages):
msgs_iter = messages
streams_iter = complementary.get("message_streams") or [[] for _ in messages]
targets_iter = complementary.get("target_message_indices") or [[] for _ in messages]
else:
msgs_iter = [messages]
streams_iter = [list(complementary.get("message_streams") or [])]
targets_iter = [list(complementary.get("target_message_indices") or [])]
for msg, streams, targets, (ids, attn, labels, predict_action, prompt) in zip(
msgs_iter, streams_iter, targets_iter, encoded, strict=False
):
target_set = {int(i) for i in targets}
annotated_msgs = [
{
**m,
"stream": streams[i] if i < len(streams) else None,
"target": True if i in target_set else None,
}
for i, m in enumerate(msg)
]
_dump_recipe_sample(
messages=annotated_msgs,
prompt_text=prompt,
token_ids=ids.tolist(),
labels=labels.tolist(),
predict_actions=bool(predict_action.item()),
tokenizer=tokenizer,
)
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
+1 -67
View File
@@ -319,47 +319,6 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
return p.parse_args(argv)
def _log_obs_tensors_once(label: str, obs: Any, flag: dict) -> None:
"""Print shape / dtype / per-channel stats of every observation tensor
going into the policy, exactly once per provider lifetime.
Used to bisect train/inference mismatches: if the dry-run path
and the robot path produce identifiably different tensors here
(e.g. one is batched twice, one has a different range, one is on
a different device), the LM head's collapse on the live robot is
a tensor-shape bug, not a distribution-shift problem. If the
tensors *do* match byte-for-byte and the head still collapses,
only then is the scene-content OOD hypothesis the right one.
"""
if flag.get("done") or not isinstance(obs, dict):
return
flag["done"] = True
import torch as _torch # noqa: PLC0415
for k, v in obs.items():
if not isinstance(k, str) or not k.startswith("observation."):
continue
if isinstance(v, _torch.Tensor):
try:
stats = (
f"min={float(v.min()):.4f} max={float(v.max()):.4f} "
f"mean={float(v.mean()):.4f} std={float(v.float().std()):.4f}"
)
except Exception: # noqa: BLE001
stats = "(stats unavailable)"
logger.warning(
"obs[%s] %-30s shape=%s dtype=%s device=%s %s",
label,
k,
tuple(v.shape),
v.dtype,
v.device,
stats,
)
else:
logger.warning("obs[%s] %-30s type=%s value=%r", label, k, type(v).__name__, v)
# Columns the runtime supplies itself via its own message stream — strip
# them so ``RenderMessagesStep`` / ``PI052TextTokenizerStep`` are no-ops.
_RUNTIME_OWNED_LANGUAGE_COLS = ("language_persistent", "language_events")
@@ -451,8 +410,6 @@ def _build_observation_provider(
so ``RenderMessagesStep`` and ``PI052TextTokenizerStep`` are
no-ops; the runtime supplies its own messages from current state.
"""
import torch # noqa: PLC0415
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
ds = LeRobotDataset(dataset_repo_id, episodes=[episode])
@@ -485,7 +442,6 @@ def _build_observation_provider(
)
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
_logged = {"done": False}
def _provider() -> dict | None:
idx = state["cursor"]
@@ -498,26 +454,7 @@ def _build_observation_provider(
if preprocessor is not None:
sample = preprocessor(sample)
_log_obs_tensors_once("dry-run", sample, _logged)
observation = _select_observation_to_device(sample, device)
# Defensive: if something further upstream forgot the batch
# dim, add it now so downstream Tensor ops don't crash.
# ``add_batch_dim`` already ran inside the preprocessor; an
# unbatched tensor at this point means a step somewhere
# produced an unbatched output. Best-effort fix. (Robot path
# gets a batch dim from ``build_inference_frame`` / the
# generic fallback, so it doesn't need this.)
for k, v in list(observation.items()):
if (
isinstance(v, torch.Tensor)
and v.ndim > 0
and v.shape[0] != 1
and v.ndim < 4
and "image" not in k
):
observation[k] = v.unsqueeze(0)
return observation
return _select_observation_to_device(sample, device)
return _provider
@@ -839,7 +776,6 @@ def _build_robot_observation_provider(
# head's distribution at position 0 collapses to its dominant
# mode (a memorised ``\n``-only run in this checkpoint).
_resize_logged = {"done": False}
_obs_logged = {"done": False}
target_image_shapes: dict[str, tuple[int, int]] = {}
if ds_features:
for fkey, fmeta in ds_features.items():
@@ -974,8 +910,6 @@ def _build_robot_observation_provider(
return None
obs_tensors = processed if isinstance(processed, dict) else {}
_log_obs_tensors_once("robot", obs_tensors, _obs_logged)
return _select_observation_to_device(obs_tensors, torch_device)
return _provider