feat(smolvla2): per-component prompt dropout + augmented training script

Two complementary regularisers to attack the
``text_loss=6e-6 = memorised one dataset`` failure mode that's
making the model collapse on real-robot input:

1. **Per-component prompt dropout** (Pi0.7 §V.E / plan's
   ``feat/pi05-prompt-dropout`` follow-up).
   ``SmolVLA2ChatTokenizerStep`` gains
   ``plan_dropout_prob`` / ``memory_dropout_prob`` /
   ``subtask_dropout_prob`` knobs (default 0.0 — opt-in). At training,
   non-target messages whose rendered content starts with
   ``Plan:`` / ``Memory:`` / ``Current subtask:`` etc. are dropped
   with their respective probability before tokenisation, with a
   deterministic per-sample RNG keyed off the dataset ``index``.
   ``target_message_indices`` is re-mapped so the supervision still
   lands on the right turn. Forces the model to handle missing
   plan/memory/subtask context — directly attacks the real-robot
   collapse where a stale or empty plan field puts the prompt OOD.

   Surfaced on ``SmolVLA2Config`` as three floats so they're
   ``--policy.<knob>=<value>``-controllable from the train CLI;
   plumbed through ``make_smolvla2_pre_post_processors``.

2. **Image augmentation** is already wired in lerobot via
   ``--dataset.image_transforms.enable=true`` (torchvision v2
   ColorJitter + SharpnessJitter + RandomAffine, default 3 of 6
   sampled per frame). No code change needed — just a CLI flag.

``examples/training/smolvla2_hirobot.slurm`` shows the full
training command with both enabled. Drop-in replacement for the
ad-hoc SLURM script Pepijn was using locally; same args, plus the
three dropout probs and the image-transforms flag.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 15:52:32 +02:00
parent c36de3a3e8
commit 01e2228b24
4 changed files with 241 additions and 3 deletions
@@ -70,6 +70,22 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
padding: str = "longest"
padding_side: str = "right"
tools: list[dict[str, Any]] | None = None
# --- Per-component prompt dropout (Pi0.7 §V.E, plan follow-up
# ``feat/pi05-prompt-dropout``). At training, drop non-target
# messages whose content was substituted from the named recipe
# binding with the given probability. Forces the model to handle
# missing context at inference — directly attacks the memorisation
# collapse where ``current_subtask=""`` puts the prompt OOD. All
# default to 0.0 (no dropout) so behaviour is identical until
# explicitly opted in via the training config.
plan_dropout_prob: float = 0.0
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0
interjection_dropout_prob: float = 0.0
# Optional seed for the per-sample RNG. ``None`` ⇒ use
# ``sample_idx`` derived from the transition (when present), so
# dropout is reproducible across runs but varies per sample.
dropout_seed: int | None = None
def __post_init__(self) -> None:
# Lazy: don't load the tokenizer until the step actually runs,
@@ -101,19 +117,38 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
tokenizer = self._get_tokenizer()
# Pull a sample_idx for the dropout RNG. ``index`` is the
# canonical per-frame key on ``LeRobotDataset`` samples and
# flows through into ``COMPLEMENTARY_DATA`` unchanged. When
# absent (e.g. inference) we fall back to 0 which is harmless
# because the dropout probs are also 0 at inference time.
sample_idx_raw = comp.get("index")
if hasattr(sample_idx_raw, "item"):
try:
sample_idx_raw = sample_idx_raw.item()
except Exception: # noqa: BLE001
pass
if _is_batched_messages(messages):
indices_iter = (
sample_idx_raw
if isinstance(sample_idx_raw, (list, tuple))
else [sample_idx_raw] * len(messages)
)
encoded = [
self._encode_messages(
tokenizer,
msg,
list(streams),
sorted(int(i) for i in indices),
sorted(int(i) for i in tgt_indices),
sample_idx=int(s_idx) if s_idx is not None else None,
)
for msg, streams, indices in zip(
for msg, streams, tgt_indices, s_idx in zip(
messages,
comp.get("message_streams") or [[] for _ in messages],
comp.get("target_message_indices") or [[] for _ in messages],
strict=True,
indices_iter,
strict=False,
)
]
else:
@@ -123,6 +158,7 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
messages,
list(comp.get("message_streams") or []),
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
sample_idx=int(sample_idx_raw) if sample_idx_raw is not None else None,
)
]
@@ -190,7 +226,15 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
messages: list[dict[str, Any]],
message_streams: list[str | None],
target_indices: list[int],
sample_idx: int | None = None,
) -> tuple[list[int], list[int], bool]:
# Apply per-component prompt dropout *before* tokenisation, so
# the dropped messages don't contribute tokens or label-mask
# positions at all. Re-maps ``target_indices`` to account for
# removed messages.
messages, target_indices = self._apply_prompt_dropout(
messages, target_indices, sample_idx
)
text_messages = [_strip_lerobot_blocks(m) for m in messages]
full_ids = tokenizer.apply_chat_template(
@@ -231,6 +275,62 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
)
return [int(i) for i in full_ids], labels, predict_actions
def _apply_prompt_dropout(
self,
messages: list[dict[str, Any]],
target_indices: list[int],
sample_idx: int | None,
) -> tuple[list[dict[str, Any]], list[int]]:
"""Probabilistically drop non-target context messages.
Heuristic content sniffing — matches the prefix strings that
``smolvla2_hirobot.yaml``'s recipes use when injecting plan /
memory / subtask / interjection content. Anything else is
kept unchanged. Target messages are never dropped (we still
need their tokens for supervision).
Returns ``(new_messages, new_target_indices)`` where the
indices are re-mapped to point at the same target turns in
the trimmed list.
"""
probs = {
"plan": float(self.plan_dropout_prob or 0.0),
"memory": float(self.memory_dropout_prob or 0.0),
"subtask": float(self.subtask_dropout_prob or 0.0),
"interjection": float(self.interjection_dropout_prob or 0.0),
}
if not any(p > 0.0 for p in probs.values()):
return messages, target_indices
# Deterministic per-sample RNG so dropout is reproducible
# across runs (matters for debugging / repro) but varies
# frame-to-frame.
import random # noqa: PLC0415
seed_int = self.dropout_seed if self.dropout_seed is not None else (sample_idx or 0)
rng = random.Random(int(seed_int) & 0xFFFFFFFF)
target_set = set(target_indices)
keep_flags: list[bool] = []
for i, msg in enumerate(messages):
if i in target_set:
keep_flags.append(True)
continue
kind = _classify_message_for_dropout(msg)
if kind and rng.random() < probs.get(kind, 0.0):
keep_flags.append(False)
else:
keep_flags.append(True)
new_messages = [m for m, keep in zip(messages, keep_flags) if keep]
# Re-map target_indices: each old index drops by the count of
# falsy flags before it.
new_target_indices: list[int] = []
for old_idx in target_indices:
dropped_before = sum(1 for k in keep_flags[:old_idx] if not k)
new_target_indices.append(old_idx - dropped_before)
return new_messages, sorted(new_target_indices)
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
@@ -296,6 +396,39 @@ def _is_batched_messages(messages: Any) -> bool:
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
def _classify_message_for_dropout(message: dict[str, Any]) -> str | None:
"""Best-effort classification of which recipe binding contributed
to this message, used for per-component dropout.
The canonical recipe authors plan/memory/subtask injections with
distinctive prefix strings in the rendered content. Matching on
those prefixes is brittle if a future recipe author uses
different wording — but it's also localised to one place and
only affects the dropout fraction (never the actual semantics).
Returns ``None`` for messages we don't recognise; those are
always kept.
"""
content = message.get("content")
if isinstance(content, list):
text_parts: list[str] = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
t = block.get("text")
if isinstance(t, str):
text_parts.append(t)
content = "\n".join(text_parts)
if not isinstance(content, str):
return None
head = content.lstrip().lower()
if head.startswith("plan:") or head.startswith("previous plan"):
return "plan"
if head.startswith("memory:") or head.startswith("previous memory"):
return "memory"
if head.startswith("current subtask") or head.startswith("completed subtask"):
return "subtask"
return None
def _as_token_ids(value: Any) -> list[int]:
if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()):
value = value["input_ids"]
@@ -84,6 +84,24 @@ class SmolVLA2Config(SmolVLAConfig):
effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
which is occasionally useful for ablations."""
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
# At training, randomly drop non-target context messages whose
# content was substituted from the named recipe binding. Forces
# the model to handle missing context — directly attacks the
# memorisation collapse where a stale or missing plan/memory at
# inference puts the prompt out-of-distribution and the LM head
# falls back to dominant-mode fragments. All default to 0.0 so
# behaviour is identical until explicitly enabled.
plan_dropout_prob: float = 0.0
"""Drop messages whose content starts with ``Plan:`` or ``Previous plan``
with this probability per sample."""
memory_dropout_prob: float = 0.0
"""Drop messages whose content starts with ``Memory:`` or ``Previous memory``
with this probability per sample."""
subtask_dropout_prob: float = 0.0
"""Drop messages whose content starts with ``Current subtask`` or
``Completed subtask`` with this probability per sample."""
def __post_init__(self) -> None:
super().__post_init__()
# Backbone needs gradients flowing through its text path when the
@@ -83,6 +83,9 @@ def make_smolvla2_pre_post_processors(
tokenizer_name=config.vlm_model_name,
max_length=config.tokenizer_max_length,
padding=config.pad_language_to,
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(