mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix(smolvla2): reject gibberish high-level generations
Memorised models can collapse to dominant-mode outputs (the JSON-token salad ``":":":":...`` from VQA training) when the prompt drifts even slightly from training distribution. Without a guard, that gibberish lands in ``current_subtask`` / ``current_plan`` / ``current_memory``, which feeds the next tick's prompt and cascades into worse outputs. The user observed exactly this: a clean run followed by a tick that wrote ``" " "`` into plan and memory, then slow recovery several ticks later. Add ``_looks_like_gibberish`` heuristic (alpha density, repeating chars, JSON-prefix sniff) and apply it before mutating state in ``HighLevelSubtaskFwd`` / ``MemoryUpdateFwd`` / ``UserInterjectionFwd``. Bad generations are logged inline (``[info] subtask gen rejected (gibberish): "":":":..."``) so the user can see what was dropped, but the state stays at its last-known-good value (typically the dataset bootstrap) instead of being polluted. VQA path is intentionally exempt — its training targets *are* JSON-shaped, so the heuristic would false-positive on them. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -271,6 +271,9 @@ class HighLevelSubtaskFwd(InferenceStep):
|
|||||||
msg = _generate_with_policy(
|
msg = _generate_with_policy(
|
||||||
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
self.policy, ctx, observation=observation, state=state, label="subtask gen"
|
||||||
)
|
)
|
||||||
|
if msg and _looks_like_gibberish(msg):
|
||||||
|
push_log(state, f" [info] subtask gen rejected (gibberish): {msg[:60]!r}")
|
||||||
|
return None
|
||||||
if msg:
|
if msg:
|
||||||
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
changed = set_if_changed(state, "current_subtask", msg, label="subtask")
|
||||||
if changed:
|
if changed:
|
||||||
@@ -307,6 +310,9 @@ class MemoryUpdateFwd(InferenceStep):
|
|||||||
new_memory = _generate_with_policy(
|
new_memory = _generate_with_policy(
|
||||||
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
self.policy, ctx, observation=observation, state=state, label="memory gen"
|
||||||
)
|
)
|
||||||
|
if new_memory and _looks_like_gibberish(new_memory):
|
||||||
|
push_log(state, f" [info] memory gen rejected (gibberish): {new_memory[:60]!r}")
|
||||||
|
return None
|
||||||
if new_memory:
|
if new_memory:
|
||||||
set_if_changed(state, "current_memory", new_memory, label="memory")
|
set_if_changed(state, "current_memory", new_memory, label="memory")
|
||||||
return None
|
return None
|
||||||
@@ -340,11 +346,16 @@ class UserInterjectionFwd(InferenceStep):
|
|||||||
if not out:
|
if not out:
|
||||||
push_log(state, " [info] plan/say gen produced no text this tick")
|
push_log(state, " [info] plan/say gen produced no text this tick")
|
||||||
return None
|
return None
|
||||||
|
if _looks_like_gibberish(out):
|
||||||
|
push_log(state, f" [info] plan/say gen rejected (gibberish): {out[:60]!r}")
|
||||||
|
return None
|
||||||
# Heuristic split: model is trained to emit one assistant turn
|
# Heuristic split: model is trained to emit one assistant turn
|
||||||
# carrying both plan text AND a `say` tool call. Look for a
|
# carrying both plan text AND a `say` tool call. Look for a
|
||||||
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
# "<say>...</say>" or "say(...)" marker; fall back to whole
|
||||||
# text → plan, no speech.
|
# text → plan, no speech.
|
||||||
plan_text, speech_text = _split_plan_and_say(out)
|
plan_text, speech_text = _split_plan_and_say(out)
|
||||||
|
if plan_text and _looks_like_gibberish(plan_text):
|
||||||
|
plan_text = ""
|
||||||
if plan_text:
|
if plan_text:
|
||||||
set_if_changed(state, "current_plan", plan_text, label="plan")
|
set_if_changed(state, "current_plan", plan_text, label="plan")
|
||||||
if speech_text:
|
if speech_text:
|
||||||
@@ -390,6 +401,9 @@ class AskVQAFwd(InferenceStep):
|
|||||||
answer = _generate_with_policy(
|
answer = _generate_with_policy(
|
||||||
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
self.policy, ctx, observation=observation, state=state, label="vqa gen"
|
||||||
)
|
)
|
||||||
|
# VQA answers are intentionally JSON-like during training, so
|
||||||
|
# ``_looks_like_gibberish`` would false-positive on them. Keep
|
||||||
|
# the answer as-is — the VQA panel line lets the user judge.
|
||||||
if answer:
|
if answer:
|
||||||
push_log(state, f" vqa: {answer}")
|
push_log(state, f" vqa: {answer}")
|
||||||
state["recent_vqa_query"] = None
|
state["recent_vqa_query"] = None
|
||||||
@@ -432,6 +446,38 @@ class DispatchToolCalls(InferenceStep):
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _looks_like_gibberish(text: str) -> bool:
|
||||||
|
"""Heuristically detect generation that's clearly off the rails.
|
||||||
|
|
||||||
|
Memorised models can collapse to dominant-mode outputs (often the
|
||||||
|
JSON-token salad ``":":":":...`` from VQA training) when the prompt
|
||||||
|
drifts even slightly from training distribution. If we accept those
|
||||||
|
as new state, they pollute the next tick's prompt and cascade into
|
||||||
|
worse outputs. Reject anything that looks pathological:
|
||||||
|
|
||||||
|
* empty / whitespace-only
|
||||||
|
* mostly punctuation (``"``, ``:``, ``,``)
|
||||||
|
* a single character repeated past the threshold
|
||||||
|
* starts with ``":"`` and contains no letters
|
||||||
|
|
||||||
|
The thresholds are intentionally lenient — a real subtask like
|
||||||
|
``"close the gripper"`` has ~70%+ alpha characters, while gibberish
|
||||||
|
like ``":":":"`` has ~0%.
|
||||||
|
"""
|
||||||
|
if not text or not text.strip():
|
||||||
|
return True
|
||||||
|
stripped = text.strip()
|
||||||
|
alpha = sum(1 for c in stripped if c.isalpha())
|
||||||
|
if alpha < max(3, len(stripped) // 8):
|
||||||
|
return True
|
||||||
|
if stripped.startswith('":') and stripped.count('"') > stripped.count(" "):
|
||||||
|
return True
|
||||||
|
# Single repeating char: e.g. ``""""""``
|
||||||
|
if len(set(stripped)) <= 2 and len(stripped) > 4:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _control_context_messages(
|
def _control_context_messages(
|
||||||
state: dict[str, Any],
|
state: dict[str, Any],
|
||||||
*,
|
*,
|
||||||
|
|||||||
Reference in New Issue
Block a user