feat(inference): repetition_penalty + no_repeat_ngram_size for select_message
Under-trained LM heads (small dataset + a few thousand steps on a
chat-pretrained backbone) collapse into n-gram loops under greedy
decoding — observed in a real-robot run as
"the robot arm extends and retracts and retracts from the
beige surface and retracts from the surface"
repeating the same trigram across the whole 256-token budget.
Added the two standard HF generation knobs to
``SmolVLA2Policy.select_message``:
* ``repetition_penalty`` (1.0 = off) — divides positive logits /
multiplies negative logits for already-emitted token ids.
* ``no_repeat_ngram_size`` (0 = off) — hard-bans any token that
would complete an n-gram already present in the generated suffix.
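
For reference, both transforms fit in a few lines of torch. A
minimal standalone sketch of one decoding step (function and
variable names here are illustrative, not the in-tree code):

    import torch

    def apply_repetition_controls(
        logits: torch.Tensor,   # 1-D, vocab-sized logit row
        generated: list[int],   # token ids emitted so far
        penalty: float = 1.0,
        ngram: int = 0,
    ) -> torch.Tensor:
        # repetition_penalty: shrink positive / amplify negative
        # logits of every already-emitted id.
        if penalty != 1.0 and generated:
            ids = list(set(generated))
            vals = logits[ids]
            logits[ids] = torch.where(vals > 0, vals / penalty, vals * penalty)
        # no_repeat_ngram_size: -inf any id that would complete an
        # n-gram already present in ``generated``.
        if ngram > 0 and len(generated) >= ngram:
            prefix = tuple(generated[-(ngram - 1):]) if ngram > 1 else ()
            for i in range(len(generated) - ngram + 1):
                g = tuple(generated[i : i + ngram])
                if g[:-1] == prefix:
                    logits[g[-1]] = float("-inf")
        return logits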
Implemented via a small ``_ngram_banned_ids`` helper that mirrors
HF's ``_get_ngrams`` semantics.
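A toy trace of the ban-list semantics with ``n=3``:

    generated = [7, 4, 9, 7, 4]    # current (n-1)-gram suffix: (7, 4)
    # trigram (7, 4, 9) already occurred, so emitting 9 would repeat it:
    # _ngram_banned_ids(generated, 3) -> {9}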
Wired through ``_generate_with_policy`` to all four call sites
(subtask, memory, plan/say, vqa) and exposed on the runtime CLI as
``--text_repetition_penalty`` (try ``1.2``-``1.5``) and
``--text_no_repeat_ngram_size`` (try ``3``).
Empirically, ``--text_no_repeat_ngram_size=3`` alone usually breaks
the trigram-loop failure mode without distorting the next-token
distribution; combine with ``--text_repetition_penalty=1.2`` for
heavier collapses.
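
Quick sanity check of that claim with the sketch above (token ids
and the single-logit "model" are made up):

    generated = [11, 12, 13, 11, 12]   # "... and retracts and" looping
    logits = torch.zeros(50)
    logits[13] = 5.0                   # greedy argmax would close the loop
    logits = apply_repetition_controls(logits, generated, penalty=1.0, ngram=3)
    assert int(torch.argmax(logits)) != 13   # trigram completion is banned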
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -396,6 +396,8 @@ class HighLevelSubtaskFwd(InferenceStep):
             min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
             temperature=float(state.get("text_gen_temperature") or 0.0),
             top_p=float(state.get("text_gen_top_p") or 1.0),
+            repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
+            no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
         )
         # Diagnostics: surface what the model is *actually* producing
         # at chunk boundaries, even when the output gets rejected or
@@ -488,6 +490,11 @@ class MemoryUpdateFwd(InferenceStep):
             observation=observation,
             state=state,
             label="memory gen",
+            min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
+            temperature=float(state.get("text_gen_temperature") or 0.0),
+            top_p=float(state.get("text_gen_top_p") or 1.0),
+            repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
+            no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
         )
         state["last_memory_raw"] = new_memory or ""
         if new_memory and _looks_like_gibberish(new_memory):
@@ -531,6 +538,11 @@ class UserInterjectionFwd(InferenceStep):
             observation=observation,
             state=state,
             label="plan/say gen",
+            min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
+            temperature=float(state.get("text_gen_temperature") or 0.0),
+            top_p=float(state.get("text_gen_top_p") or 1.0),
+            repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
+            no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
         )
         if not out:
             # Don't log every empty completion — happens repeatedly on
@@ -600,6 +612,11 @@ class AskVQAFwd(InferenceStep):
             observation=observation,
             state=state,
             label="vqa gen",
+            min_new_tokens=int(state.get("text_gen_min_new_tokens") or 0),
+            temperature=float(state.get("text_gen_temperature") or 0.0),
+            top_p=float(state.get("text_gen_top_p") or 1.0),
+            repetition_penalty=float(state.get("text_gen_repetition_penalty") or 1.0),
+            no_repeat_ngram_size=int(state.get("text_gen_no_repeat_ngram_size") or 0),
         )
         # VQA answers are intentionally JSON-like during training, so
         # ``_looks_like_gibberish`` would false-positive on them. Keep
@@ -829,6 +846,8 @@ def _generate_with_policy(
     min_new_tokens: int = 0,
     temperature: float = 0.0,
     top_p: float = 1.0,
+    repetition_penalty: float = 1.0,
+    no_repeat_ngram_size: int = 0,
 ) -> str:
     """Drive ``policy.select_message`` with a chat batch (and optional obs).
 
@@ -869,6 +888,8 @@ def _generate_with_policy(
             min_new_tokens=min_new_tokens,
             temperature=temperature,
             top_p=top_p,
+            repetition_penalty=repetition_penalty,
+            no_repeat_ngram_size=no_repeat_ngram_size,
         )
     except Exception as exc:  # noqa: BLE001
         logger.warning("%s failed: %s", label, exc, exc_info=logger.isEnabledFor(logging.DEBUG))

@@ -85,6 +85,25 @@ def _locate_lang_range(prefix_att_masks: Tensor, num_lang: int) -> tuple[int, in
     return lang_start, lang_end
 
 
+def _ngram_banned_ids(generated: list[int], n: int) -> set[int]:
+    """Standard ``no_repeat_ngram_size`` ban list.
+
+    For every ``(n-1)``-gram suffix in ``generated`` that already
+    appeared earlier in ``generated``, the next-token id that
+    *completed* it is banned — emitting it would produce a repeated
+    ``n``-gram. Matches HF's ``_get_ngrams`` semantics.
+    """
+    if n <= 0 or len(generated) < n:
+        return set()
+    cur_prefix = tuple(generated[-(n - 1):]) if n > 1 else ()
+    banned: set[int] = set()
+    for i in range(len(generated) - n + 1):
+        ngram = tuple(generated[i : i + n])
+        if ngram[:-1] == cur_prefix:
+            banned.add(ngram[-1])
+    return banned
+
+
 def _shifted_ce(logits: Tensor, text_labels: Tensor) -> Tensor:
     """Next-token CE: hidden at t predicts label at t+1, ignore_index=-100."""
     num_lang = logits.shape[1]
@@ -434,6 +453,8 @@ class SmolVLA2Policy(SmolVLAPolicy):
         eos_token_id: int | None = None,
         temperature: float = 0.0,
         top_p: float = 1.0,
+        repetition_penalty: float = 1.0,
+        no_repeat_ngram_size: int = 0,
         tokenizer: Any = None,
     ) -> str:
         """Generate text continuation from the chat-templated prompt.
@@ -566,6 +587,28 @@ class SmolVLA2Policy(SmolVLAPolicy):
             if special_ids_set and len(generated) < min_new_tokens:
                 for sid in special_ids_set:
                     logits_step[..., sid] = float("-inf")
+            # Repetition controls — break the "extends and retracts and
+            # retracts" loops that undertrained LM heads fall into under
+            # greedy decoding. Standard HF generation behaviour:
+            # ``repetition_penalty`` divides positive logits / multiplies
+            # negative logits for already-emitted token ids;
+            # ``no_repeat_ngram_size`` hard-bans any token that would
+            # complete an n-gram already seen in the generated suffix.
+            if repetition_penalty != 1.0 and generated:
+                seen_ids = list(set(generated))
+                logit_row = logits_step[0]
+                vals = logit_row[seen_ids]
+                vals = torch.where(
+                    vals > 0,
+                    vals / repetition_penalty,
+                    vals * repetition_penalty,
+                )
+                logit_row[seen_ids] = vals
+            if no_repeat_ngram_size > 0 and len(generated) >= no_repeat_ngram_size:
+                banned = _ngram_banned_ids(generated, no_repeat_ngram_size)
+                if banned:
+                    for bad in banned:
+                        logits_step[..., bad] = float("-inf")
             next_ids = self._sample_next_token(logits_step, temperature, top_p)
             tok_id = int(next_ids[0].item())
             generated.append(tok_id)

@@ -288,6 +288,30 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
         default=1.0,
         help="Nucleus filtering for high-level text gen.",
     )
+    p.add_argument(
+        "--text_repetition_penalty",
+        type=float,
+        default=1.0,
+        help=(
+            "Per-token repetition penalty (HF-style: divides positive "
+            "logits / multiplies negative logits for already-emitted "
+            "ids). ``1.0`` = off. Try ``1.2``-``1.5`` on under-trained "
+            "checkpoints that fall into n-gram loops like "
+            "``\"extends and retracts and retracts and retracts\"``."
+        ),
+    )
+    p.add_argument(
+        "--text_no_repeat_ngram_size",
+        type=int,
+        default=0,
+        help=(
+            "Hard-ban any token that would complete an n-gram already "
+            "present in the generated suffix. ``0`` = off. ``3`` kills "
+            "most LM-head repetition collapses without distorting the "
+            "next-token distribution; ``2`` is more aggressive but can "
+            "block legit ``\"the X\"`` patterns."
+        ),
+    )
     p.add_argument("-v", "--verbose", action="store_true", help="Enable DEBUG logging.")
     return p.parse_args(argv)
 
@@ -1333,6 +1357,12 @@ def main(argv: list[str] | None = None) -> int:
     runtime.state["text_gen_min_new_tokens"] = int(getattr(args, "text_min_new_tokens", 0) or 0)
     runtime.state["text_gen_temperature"] = float(getattr(args, "text_temperature", 0.0) or 0.0)
     runtime.state["text_gen_top_p"] = float(getattr(args, "text_top_p", 1.0) or 1.0)
+    runtime.state["text_gen_repetition_penalty"] = float(
+        getattr(args, "text_repetition_penalty", 1.0) or 1.0
+    )
+    runtime.state["text_gen_no_repeat_ngram_size"] = int(
+        getattr(args, "text_no_repeat_ngram_size", 0) or 0
+    )
     if args.task:
         runtime.set_task(args.task)
     # Seed plan/memory/subtask so the first prompt the runtime builds