mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
pi052: drop debug scaffolding left over from training/inference bug hunts
Three diagnostic surfaces shipped in PR3 that don't belong in a clean release: * ``LEROBOT_DUMP_RECIPE_SAMPLES`` env-var dump (~70 LOC in text_processor_pi052.py): pretty-prints the next N rendered samples with ``[TGT]...[/TGT]`` markers over supervised spans. One-off training-inspection tool — no production user, never wired into a CLI flag, only useful while iterating on the recipe. Drop the module constants, the ``_is_dump_rank`` / ``_dump_recipe_sample`` helpers, the call site, and the now-unused ``import os``. * ``_log_obs_tensors_once()`` in lerobot_pi052_runtime.py: the docstring literally says "Used to bisect train/inference mismatches" — a debugging artifact from when the LM head was collapsing on the live robot. Logged unconditionally at WARNING level from both the dataset-driven and robot-driven providers, with no ``--verbose`` gate. Drop the function, both call sites, and the ``_logged`` / ``_obs_logged`` flag dicts that fed them. (``_resize_logged`` is kept — it gates the operationally useful camera-size sanity log.) * Defensive ``unsqueeze(0)`` block in the dataset observation provider: papered over an upstream bug where some preprocessor step could produce an unbatched tensor. ``AddBatchDimensionProcessorStep`` is reliable in the current pipeline — pi052 tests still pass with the block removed. If the bug ever resurfaces it should be fixed at the source, not silently re-batched here. Net: -169 LOC. All 30 ``tests/policies/pi052/`` tests pass. The ``<loc>`` token plumbing (``register_paligemma_loc_tokens``, ``_loc_token``, ``suppress_loc_tokens`` runtime gate) is left as-is — it's the actual mechanism for VQA spatial answers, not scaffolding, and the ``suppress_loc_tokens=True`` callers on subtask/memory/ interjection paths and ``=False`` on the VQA path are intentional asymmetric behaviour, not a bug-routing knob. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -37,7 +37,6 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
@@ -52,78 +51,6 @@ from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TO
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Debug helper — when ``LEROBOT_DUMP_RECIPE_SAMPLES=N`` is set, the next N
|
||||
# samples processed (on rank 0) are pretty-printed with ``[TGT]...[/TGT]``
|
||||
# markers over the spans the LM head will be supervised on.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_DUMP_BUDGET = int(os.environ.get("LEROBOT_DUMP_RECIPE_SAMPLES", "0"))
|
||||
_DUMPED_SO_FAR = 0
|
||||
|
||||
|
||||
def _is_dump_rank() -> bool:
|
||||
rank = os.environ.get("RANK") or os.environ.get("LOCAL_RANK") or "0"
|
||||
try:
|
||||
return int(rank) == 0
|
||||
except ValueError:
|
||||
return True
|
||||
|
||||
|
||||
def _dump_recipe_sample(
|
||||
*,
|
||||
messages: list[dict[str, Any]],
|
||||
prompt_text: str,
|
||||
token_ids: list[int],
|
||||
labels: list[int],
|
||||
predict_actions: bool,
|
||||
tokenizer: Any,
|
||||
) -> None:
|
||||
"""Pretty-print one rendered sample. Stops once the global budget is hit."""
|
||||
global _DUMPED_SO_FAR
|
||||
if _DUMPED_SO_FAR >= _DUMP_BUDGET or not _is_dump_rank():
|
||||
return
|
||||
_DUMPED_SO_FAR += 1
|
||||
|
||||
parts: list[str] = []
|
||||
i = 0
|
||||
while i < len(labels):
|
||||
if labels[i] == -100:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] == -100:
|
||||
j += 1
|
||||
parts.append(tokenizer.decode(token_ids[i:j], skip_special_tokens=False))
|
||||
i = j
|
||||
else:
|
||||
j = i
|
||||
while j < len(labels) and labels[j] != -100:
|
||||
j += 1
|
||||
tgt_text = tokenizer.decode(token_ids[i:j], skip_special_tokens=False)
|
||||
parts.append(f"[TGT]{tgt_text}[/TGT]")
|
||||
i = j
|
||||
annotated = "".join(parts)
|
||||
|
||||
n_tgt = sum(1 for l in labels if l != -100)
|
||||
print(
|
||||
"\n========== RECIPE SAMPLE DUMP "
|
||||
f"({_DUMPED_SO_FAR}/{_DUMP_BUDGET}) ==========",
|
||||
flush=True,
|
||||
)
|
||||
print(f" predict_actions: {predict_actions}", flush=True)
|
||||
print(f" rendered messages ({len(messages)}):", flush=True)
|
||||
for m in messages:
|
||||
stream = m.get("stream")
|
||||
target = m.get("target")
|
||||
role = m.get("role")
|
||||
content = m.get("content")
|
||||
print(f" - role={role} stream={stream} target={target}", flush=True)
|
||||
print(f" content: {content!r}", flush=True)
|
||||
print(f" rendered prompt:\n {prompt_text!r}", flush=True)
|
||||
print(f" token count: {len(token_ids)} (target tokens: {n_tgt})", flush=True)
|
||||
print(f" decoded (with target markers):\n {annotated}", flush=True)
|
||||
print("==============================================\n", flush=True)
|
||||
|
||||
|
||||
def _content_to_text(content: Any) -> str:
|
||||
"""Collapse a message's ``content`` (string or multimodal blocks) to text."""
|
||||
if isinstance(content, str):
|
||||
@@ -499,36 +426,6 @@ class PI052TextTokenizerStep(ProcessorStep):
|
||||
)
|
||||
]
|
||||
|
||||
if _DUMP_BUDGET > 0:
|
||||
if _is_batched_messages(messages):
|
||||
msgs_iter = messages
|
||||
streams_iter = complementary.get("message_streams") or [[] for _ in messages]
|
||||
targets_iter = complementary.get("target_message_indices") or [[] for _ in messages]
|
||||
else:
|
||||
msgs_iter = [messages]
|
||||
streams_iter = [list(complementary.get("message_streams") or [])]
|
||||
targets_iter = [list(complementary.get("target_message_indices") or [])]
|
||||
for msg, streams, targets, (ids, attn, labels, predict_action, prompt) in zip(
|
||||
msgs_iter, streams_iter, targets_iter, encoded, strict=False
|
||||
):
|
||||
target_set = {int(i) for i in targets}
|
||||
annotated_msgs = [
|
||||
{
|
||||
**m,
|
||||
"stream": streams[i] if i < len(streams) else None,
|
||||
"target": True if i in target_set else None,
|
||||
}
|
||||
for i, m in enumerate(msg)
|
||||
]
|
||||
_dump_recipe_sample(
|
||||
messages=annotated_msgs,
|
||||
prompt_text=prompt,
|
||||
token_ids=ids.tolist(),
|
||||
labels=labels.tolist(),
|
||||
predict_actions=bool(predict_action.item()),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
|
||||
obs = dict(transition.get(TransitionKey.OBSERVATION) or {})
|
||||
obs[OBS_LANGUAGE_TOKENS] = torch.stack([ids for ids, _, _, _, _ in encoded])
|
||||
obs[OBS_LANGUAGE_ATTENTION_MASK] = torch.stack([attn for _, attn, _, _, _ in encoded])
|
||||
|
||||
@@ -319,47 +319,6 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
return p.parse_args(argv)
|
||||
|
||||
|
||||
def _log_obs_tensors_once(label: str, obs: Any, flag: dict) -> None:
|
||||
"""Print shape / dtype / per-channel stats of every observation tensor
|
||||
going into the policy, exactly once per provider lifetime.
|
||||
|
||||
Used to bisect train/inference mismatches: if the dry-run path
|
||||
and the robot path produce identifiably different tensors here
|
||||
(e.g. one is batched twice, one has a different range, one is on
|
||||
a different device), the LM head's collapse on the live robot is
|
||||
a tensor-shape bug, not a distribution-shift problem. If the
|
||||
tensors *do* match byte-for-byte and the head still collapses,
|
||||
only then is the scene-content OOD hypothesis the right one.
|
||||
"""
|
||||
if flag.get("done") or not isinstance(obs, dict):
|
||||
return
|
||||
flag["done"] = True
|
||||
import torch as _torch # noqa: PLC0415
|
||||
|
||||
for k, v in obs.items():
|
||||
if not isinstance(k, str) or not k.startswith("observation."):
|
||||
continue
|
||||
if isinstance(v, _torch.Tensor):
|
||||
try:
|
||||
stats = (
|
||||
f"min={float(v.min()):.4f} max={float(v.max()):.4f} "
|
||||
f"mean={float(v.mean()):.4f} std={float(v.float().std()):.4f}"
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
stats = "(stats unavailable)"
|
||||
logger.warning(
|
||||
"obs[%s] %-30s shape=%s dtype=%s device=%s %s",
|
||||
label,
|
||||
k,
|
||||
tuple(v.shape),
|
||||
v.dtype,
|
||||
v.device,
|
||||
stats,
|
||||
)
|
||||
else:
|
||||
logger.warning("obs[%s] %-30s type=%s value=%r", label, k, type(v).__name__, v)
|
||||
|
||||
|
||||
# Columns the runtime supplies itself via its own message stream — strip
|
||||
# them so ``RenderMessagesStep`` / ``PI052TextTokenizerStep`` are no-ops.
|
||||
_RUNTIME_OWNED_LANGUAGE_COLS = ("language_persistent", "language_events")
|
||||
@@ -451,8 +410,6 @@ def _build_observation_provider(
|
||||
so ``RenderMessagesStep`` and ``PI052TextTokenizerStep`` are
|
||||
no-ops; the runtime supplies its own messages from current state.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||
|
||||
ds = LeRobotDataset(dataset_repo_id, episodes=[episode])
|
||||
@@ -485,7 +442,6 @@ def _build_observation_provider(
|
||||
)
|
||||
|
||||
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
|
||||
_logged = {"done": False}
|
||||
|
||||
def _provider() -> dict | None:
|
||||
idx = state["cursor"]
|
||||
@@ -498,26 +454,7 @@ def _build_observation_provider(
|
||||
if preprocessor is not None:
|
||||
sample = preprocessor(sample)
|
||||
|
||||
_log_obs_tensors_once("dry-run", sample, _logged)
|
||||
|
||||
observation = _select_observation_to_device(sample, device)
|
||||
# Defensive: if something further upstream forgot the batch
|
||||
# dim, add it now so downstream Tensor ops don't crash.
|
||||
# ``add_batch_dim`` already ran inside the preprocessor; an
|
||||
# unbatched tensor at this point means a step somewhere
|
||||
# produced an unbatched output. Best-effort fix. (Robot path
|
||||
# gets a batch dim from ``build_inference_frame`` / the
|
||||
# generic fallback, so it doesn't need this.)
|
||||
for k, v in list(observation.items()):
|
||||
if (
|
||||
isinstance(v, torch.Tensor)
|
||||
and v.ndim > 0
|
||||
and v.shape[0] != 1
|
||||
and v.ndim < 4
|
||||
and "image" not in k
|
||||
):
|
||||
observation[k] = v.unsqueeze(0)
|
||||
return observation
|
||||
return _select_observation_to_device(sample, device)
|
||||
|
||||
return _provider
|
||||
|
||||
@@ -839,7 +776,6 @@ def _build_robot_observation_provider(
|
||||
# head's distribution at position 0 collapses to its dominant
|
||||
# mode (a memorised ``\n``-only run in this checkpoint).
|
||||
_resize_logged = {"done": False}
|
||||
_obs_logged = {"done": False}
|
||||
target_image_shapes: dict[str, tuple[int, int]] = {}
|
||||
if ds_features:
|
||||
for fkey, fmeta in ds_features.items():
|
||||
@@ -974,8 +910,6 @@ def _build_robot_observation_provider(
|
||||
return None
|
||||
obs_tensors = processed if isinstance(processed, dict) else {}
|
||||
|
||||
_log_obs_tensors_once("robot", obs_tensors, _obs_logged)
|
||||
|
||||
return _select_observation_to_device(obs_tensors, torch_device)
|
||||
|
||||
return _provider
|
||||
|
||||
Reference in New Issue
Block a user