feat(inference): repetition_penalty + no_repeat_ngram_size for select_message

Under-trained LM heads (small dataset + a few thousand steps on a
chat-pretrained backbone) collapse into n-gram loops under greedy
decoding — observed in real-robot run as

    "the robot arm extends and retracts and retracts from the
     beige surface and retracts from the surface"

repeating the same trigram across the whole 256-token budget.

Added the two standard HF generation knobs to
``SmolVLA2Policy.select_message``:

* ``repetition_penalty`` (1.0 = off) — divides positive logits /
  multiplies negative logits for already-emitted token ids.
* ``no_repeat_ngram_size`` (0 = off) — hard-bans any token that
  would complete an n-gram already present in the generated suffix.
  Implemented via a small ``_ngram_banned_ids`` helper that mirrors
  HF's ``_get_ngrams`` semantics.

Wired through ``_generate_with_policy`` to all four call sites
(subtask, memory, plan/say, vqa) and exposed as
``--text_repetition_penalty=1.2-1.5`` and
``--text_no_repeat_ngram_size=3`` on the runtime CLI.

Empirically ``--text_no_repeat_ngram_size=3`` alone usually breaks
the trigram-loop failure mode without distorting the next-token
distribution; combine with ``--text_repetition_penalty=1.2`` for
heavier collapses.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-14 13:35:46 +02:00
parent 95033733fc
commit db2972fb6c
4 changed files with 1615 additions and 137 deletions
@@ -396,6 +396,8 @@ class HighLevelSubtaskFwd(InferenceStep):
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
)
# Diagnostics: surface what the model is *actually* producing
# at chunk boundaries, even when the output gets rejected or
@@ -488,6 +490,11 @@ class MemoryUpdateFwd(InferenceStep):
observation=observation,
state=state,
label="memory gen",
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
)
state["last_memory_raw"] = new_memory or ""
if new_memory and _looks_like_gibberish(new_memory):
@@ -531,6 +538,11 @@ class UserInterjectionFwd(InferenceStep):
observation=observation,
state=state,
label="plan/say gen",
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
)
if not out:
# Don't log every empty completion — happens repeatedly on
@@ -600,6 +612,11 @@ class AskVQAFwd(InferenceStep):
observation=observation,
state=state,
label="vqa gen",
min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
temperature=float(state.get("text_gen_temperature") or 0.0),
top_p=float(state.get("text_gen_top_p") or 1.0),
repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
)
# VQA answers are intentionally JSON-like during training, so
# ``_looks_like_gibberish`` would false-positive on them. Keep
@@ -829,6 +846,8 @@ def _generate_with_policy(
min_new_tokens: int = 0,
temperature: float = 0.0,
top_p: float = 1.0,
repetition_penalty: float = 1.0,
no_repeat_ngram_size: int = 0,
) -> str:
"""Drive ``policy.select_message`` with a chat batch (and optional obs).
@@ -869,6 +888,8 @@ def _generate_with_policy(
min_new_tokens=min_new_tokens,
temperature=temperature,
top_p=top_p,
repetition_penalty=repetition_penalty,
no_repeat_ngram_size=no_repeat_ngram_size,
)
except Exception as exc: # noqa: BLE001
logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))
@@ -85,6 +85,25 @@ def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, in
return lang_start, lang_end
def _ngram_banned_ids(generated: list[int], n: int) -> set[int]:
"""Standard ``no_repeat_ngram_size`` ban list.
For every ``(n-1)``-gram suffix in ``generated`` that already
appeared earlier in ``generated``, the next-token id that
*completed* it is banned — emitting it would produce a repeated
``n``-gram. Matches HF's ``_get_ngrams`` semantics.
"""
if n <= 0 or len(generated) < n:
return set()
cur_prefix = tuple(generated[-(n - 1):]) if n > 1 else ()
banned: set[int] = set()
for i in range(len(generated) - n + 1):
ngram = tuple(generated[i : i + n])
if ngram[:-1] == cur_prefix:
banned.add(ngram[-1])
return banned
def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
"""Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
num_lang = logits.shape[1]
@@ -434,6 +453,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
eos_token_id: int | None = None,
temperature: float = 0.0,
top_p: float = 1.0,
repetition_penalty: float = 1.0,
no_repeat_ngram_size: int = 0,
tokenizer: Any = None,
) -> str:
"""Generate text continuation from the chat-templated prompt.
@@ -566,6 +587,28 @@ class SmolVLA2Policy(SmolVLAPolicy):
if special_ids_set and len(generated) < min_new_tokens:
for sid in special_ids_set:
logits_step[..., sid] = float("-inf")
# Repetition controls — break the "extends and retracts and
# retracts" loops that undertrained LM heads fall into under
# greedy decoding. Standard HF generation behaviour:
# ``repetition_penalty`` divides positive logits / multiplies
# negative logits for already-emitted token ids;
# ``no_repeat_ngram_size`` hard-bans any token that would
# complete an n-gram already seen in the generated suffix.
if repetition_penalty != 1.0 and generated:
seen_ids = list(set(generated))
logit_row = logits_step[0]
vals = logit_row[seen_ids]
vals = torch.where(
vals > 0,
vals / repetition_penalty,
vals * repetition_penalty,
)
logit_row[seen_ids] = vals
if no_repeat_ngram_size > 0 and len(generated) >= no_repeat_ngram_size:
banned = _ngram_banned_ids(generated, no_repeat_ngram_size)
if banned:
for bad in banned:
logits_step[..., bad] = float("-inf")
next_ids = self._sample_next_token(logits_step, temperature, top_p)
tok_id = int(next_ids[0].item())
generated.append(tok_id)
@@ -288,6 +288,30 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
default=1.0,
help="Nucleus filtering for high-level text gen.",
)
p.add_argument(
"--text_repetition_penalty",
type=float,
default=1.0,
help=(
"Per-token repetition penalty (HF-style: divides positive "
"logits / multiplies negative logits for already-emitted "
"ids). ``1.0`` = off. Try ``1.2``-``1.5`` on under-trained "
"checkpoints that fall into n-gram loops like "
"``\"extends and retracts and retracts and retracts\"``."
),
)
p.add_argument(
"--text_no_repeat_ngram_size",
type=int,
default=0,
help=(
"Hard-ban any token that would complete an n-gram already "
"present in the generated suffix. ``0`` = off. ``3`` kills "
"most LM-head repetition collapses without distorting the "
"next-token distribution; ``2`` is more aggressive but can "
"block legit ``\"the X\"`` patterns."
),
)
p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
return p.parse_args(argv)
@@ -1333,6 +1357,12 @@ def main(argv: list[str] | None = None) -> int:
runtime.state["text_gen_min_new_tokens"] = int(getattr(args, "text_min_new_tokens", 0) or 0)
runtime.state["text_gen_temperature"] = float(getattr(args, "text_temperature", 0.0) or 0.0)
runtime.state["text_gen_top_p"] = float(getattr(args, "text_top_p", 1.0) or 1.0)
runtime.state["text_gen_repetition_penalty"] = float(
getattr(args, "text_repetition_penalty", 1.0) or 1.0
)
runtime.state["text_gen_no_repeat_ngram_size"] = int(
getattr(args, "text_no_repeat_ngram_size", 0) or 0
)
if args.task:
runtime.set_task(args.task)
# Seed plan/memory/subtask so the first prompt the runtime builds
Generated
+1521 -137
View File
File diff suppressed because it is too large Load Diff