From 01e2228b242a094c216aa986efec9ee104ef9d79 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Tue, 12 May 2026 15:52:32 +0200 Subject: [PATCH] feat(smolvla2): per-component prompt dropout + augmented training script MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two complementary regularisers to attack the ``text_loss=6e-6 = memorised one dataset`` failure mode that's making the model collapse on real-robot input: 1. **Per-component prompt dropout** (Pi0.7 §V.E / plan's ``feat/pi05-prompt-dropout`` follow-up). ``SmolVLA2ChatTokenizerStep`` gains ``plan_dropout_prob`` / ``memory_dropout_prob`` / ``subtask_dropout_prob`` knobs (default 0.0 — opt-in). At training, non-target messages whose rendered content starts with ``Plan:`` / ``Memory:`` / ``Current subtask:`` etc. are dropped with their respective probability before tokenisation, with a deterministic per-sample RNG keyed off the dataset ``index``. ``target_message_indices`` is re-mapped so the supervision still lands on the right turn. Forces the model to handle missing plan/memory/subtask context — directly attacks the real-robot collapse where a stale or empty plan field puts the prompt OOD. Surfaced on ``SmolVLA2Config`` as three floats so they're ``--policy.=``-controllable from the train CLI; plumbed through ``make_smolvla2_pre_post_processors``. 2. **Image augmentation** is already wired in lerobot via ``--dataset.image_transforms.enable=true`` (torchvision v2 ColorJitter + SharpnessJitter + RandomAffine, default 3 of 6 sampled per frame). No code change needed — just a CLI flag. ``examples/training/smolvla2_hirobot.slurm`` shows the full training command with both enabled. Drop-in replacement for the ad-hoc SLURM script Pepijn was using locally; same args, plus the three dropout probs and the image-transforms flag. Co-Authored-By: Claude Opus 4.7 (1M context) --- examples/training/smolvla2_hirobot.slurm | 84 +++++++++++ .../smolvla2/chat_processor_smolvla2.py | 139 +++++++++++++++++- .../smolvla2/configuration_smolvla2.py | 18 +++ .../policies/smolvla2/processor_smolvla2.py | 3 + 4 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 examples/training/smolvla2_hirobot.slurm diff --git a/examples/training/smolvla2_hirobot.slurm b/examples/training/smolvla2_hirobot.slurm new file mode 100644 index 000000000..9ea327f91 --- /dev/null +++ b/examples/training/smolvla2_hirobot.slurm @@ -0,0 +1,84 @@ +#!/bin/bash +#SBATCH --job-name=smolvla2-hirobot +#SBATCH --partition=hopper-prod +#SBATCH --qos=high +#SBATCH --time=48:00:00 +#SBATCH --ntasks=1 +#SBATCH --gpus-per-task=8 + +# SmolVLA2 training on an annotated dataset, with image augmentation +# and per-component prompt dropout enabled — the two regularisers +# that move the model away from the "text_loss=6e-6 memorised one +# epoch worth of frames" failure mode toward "learns concepts, not +# pixels". +# +# What the regularisers do: +# +# * --dataset.image_transforms.enable=true: applies torchvision +# v2 ColorJitter (brightness/contrast/saturation/hue), +# SharpnessJitter and RandomAffine per frame at training time. +# Set max_num_transforms to control how many are sampled per +# frame; defaults to 3 of the 6. +# * --policy.plan_dropout_prob / memory / subtask: at training, +# randomly drop the context messages that carry the named +# binding so the model is forced to handle missing/stale context. +# Mirrors Pi0.7's prompt-component dropout (§V.E). +# +# Expected effect: text_loss plateaus higher (~0.5-2.0 instead of +# ~1e-5) and the model handles slight prompt/scene drift at +# inference instead of collapsing to memorised fragments. + +set -euo pipefail + +cd "${LEROBOT_ROOT:-$HOME/lerobot}" + +export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH" +export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}" +export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}" +export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}" +export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}" + +DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}" +POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/smolvla2_hirobot_super_poulain_tool4}" +JOB_NAME="${JOB_NAME:-smolvla2-hirobot-super-poulain-tool4}" +NUM_PROCESSES="${NUM_PROCESSES:-8}" +BATCH_SIZE="${BATCH_SIZE:-32}" +STEPS="${STEPS:-10000}" +RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}" +OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/smolvla2_hirobot_${RUN_ID}}" + +echo "Training smolvla2 on $DATASET" +echo " GPUs: $NUM_PROCESSES" +echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))" +echo " steps: $STEPS" +echo " output: $OUTPUT_DIR" +echo " augmentation: image_transforms ON, prompt dropout {plan:0.15 memory:0.15 subtask:0.20}" + +accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \ + -m lerobot.scripts.lerobot_train \ + --policy.type=smolvla2 \ + --policy.recipe_path=recipes/smolvla2_hirobot.yaml \ + --dataset.repo_id="$DATASET" \ + --dataset.revision=main \ + --dataset.video_backend=pyav \ + --dataset.image_transforms.enable=true \ + --dataset.image_transforms.max_num_transforms=3 \ + --dataset.image_transforms.random_order=true \ + --policy.plan_dropout_prob=0.15 \ + --policy.memory_dropout_prob=0.15 \ + --policy.subtask_dropout_prob=0.20 \ + --output_dir="$OUTPUT_DIR" \ + --job_name="$JOB_NAME" \ + --policy.repo_id="$POLICY_REPO_ID" \ + --policy.compile_model=false \ + --policy.device=cuda \ + --policy.tokenizer_max_length=512 \ + --steps="$STEPS" \ + --policy.scheduler_decay_steps="$STEPS" \ + --batch_size="$BATCH_SIZE" \ + --wandb.enable=true \ + --wandb.disable_artifact=true \ + --wandb.project=hirobot \ + --log_freq=100 \ + --save_freq=1000 \ + --num_workers=0 diff --git a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py index bebcdd04f..ddb08645e 100644 --- a/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py +++ b/src/lerobot/policies/smolvla2/chat_processor_smolvla2.py @@ -70,6 +70,22 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): padding: str = "longest" padding_side: str = "right" tools: list[dict[str, Any]] | None = None + # --- Per-component prompt dropout (Pi0.7 §V.E, plan follow-up + # ``feat/pi05-prompt-dropout``). At training, drop non-target + # messages whose content was substituted from the named recipe + # binding with the given probability. Forces the model to handle + # missing context at inference — directly attacks the memorisation + # collapse where ``current_subtask=""`` puts the prompt OOD. All + # default to 0.0 (no dropout) so behaviour is identical until + # explicitly opted in via the training config. + plan_dropout_prob: float = 0.0 + memory_dropout_prob: float = 0.0 + subtask_dropout_prob: float = 0.0 + interjection_dropout_prob: float = 0.0 + # Optional seed for the per-sample RNG. ``None`` ⇒ use + # ``sample_idx`` derived from the transition (when present), so + # dropout is reproducible across runs but varies per sample. + dropout_seed: int | None = None def __post_init__(self) -> None: # Lazy: don't load the tokenizer until the step actually runs, @@ -101,19 +117,38 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): tokenizer = self._get_tokenizer() + # Pull a sample_idx for the dropout RNG. ``index`` is the + # canonical per-frame key on ``LeRobotDataset`` samples and + # flows through into ``COMPLEMENTARY_DATA`` unchanged. When + # absent (e.g. inference) we fall back to 0 which is harmless + # because the dropout probs are also 0 at inference time. + sample_idx_raw = comp.get("index") + if hasattr(sample_idx_raw, "item"): + try: + sample_idx_raw = sample_idx_raw.item() + except Exception: # noqa: BLE001 + pass + if _is_batched_messages(messages): + indices_iter = ( + sample_idx_raw + if isinstance(sample_idx_raw, (list, tuple)) + else [sample_idx_raw] * len(messages) + ) encoded = [ self._encode_messages( tokenizer, msg, list(streams), - sorted(int(i) for i in indices), + sorted(int(i) for i in tgt_indices), + sample_idx=int(s_idx) if s_idx is not None else None, ) - for msg, streams, indices in zip( + for msg, streams, tgt_indices, s_idx in zip( messages, comp.get("message_streams") or [[] for _ in messages], comp.get("target_message_indices") or [[] for _ in messages], - strict=True, + indices_iter, + strict=False, ) ] else: @@ -123,6 +158,7 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): messages, list(comp.get("message_streams") or []), sorted(int(i) for i in (comp.get("target_message_indices") or [])), + sample_idx=int(sample_idx_raw) if sample_idx_raw is not None else None, ) ] @@ -190,7 +226,15 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): messages: list[dict[str, Any]], message_streams: list[str | None], target_indices: list[int], + sample_idx: int | None = None, ) -> tuple[list[int], list[int], bool]: + # Apply per-component prompt dropout *before* tokenisation, so + # the dropped messages don't contribute tokens or label-mask + # positions at all. Re-maps ``target_indices`` to account for + # removed messages. + messages, target_indices = self._apply_prompt_dropout( + messages, target_indices, sample_idx + ) text_messages = [_strip_lerobot_blocks(m) for m in messages] full_ids = tokenizer.apply_chat_template( @@ -231,6 +275,62 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep): ) return [int(i) for i in full_ids], labels, predict_actions + def _apply_prompt_dropout( + self, + messages: list[dict[str, Any]], + target_indices: list[int], + sample_idx: int | None, + ) -> tuple[list[dict[str, Any]], list[int]]: + """Probabilistically drop non-target context messages. + + Heuristic content sniffing — matches the prefix strings that + ``smolvla2_hirobot.yaml``'s recipes use when injecting plan / + memory / subtask / interjection content. Anything else is + kept unchanged. Target messages are never dropped (we still + need their tokens for supervision). + + Returns ``(new_messages, new_target_indices)`` where the + indices are re-mapped to point at the same target turns in + the trimmed list. + """ + probs = { + "plan": float(self.plan_dropout_prob or 0.0), + "memory": float(self.memory_dropout_prob or 0.0), + "subtask": float(self.subtask_dropout_prob or 0.0), + "interjection": float(self.interjection_dropout_prob or 0.0), + } + if not any(p > 0.0 for p in probs.values()): + return messages, target_indices + + # Deterministic per-sample RNG so dropout is reproducible + # across runs (matters for debugging / repro) but varies + # frame-to-frame. + import random # noqa: PLC0415 + + seed_int = self.dropout_seed if self.dropout_seed is not None else (sample_idx or 0) + rng = random.Random(int(seed_int) & 0xFFFFFFFF) + + target_set = set(target_indices) + keep_flags: list[bool] = [] + for i, msg in enumerate(messages): + if i in target_set: + keep_flags.append(True) + continue + kind = _classify_message_for_dropout(msg) + if kind and rng.random() < probs.get(kind, 0.0): + keep_flags.append(False) + else: + keep_flags.append(True) + + new_messages = [m for m, keep in zip(messages, keep_flags) if keep] + # Re-map target_indices: each old index drops by the count of + # falsy flags before it. + new_target_indices: list[int] = [] + for old_idx in target_indices: + dropped_before = sum(1 for k in keep_flags[:old_idx] if not k) + new_target_indices.append(old_idx - dropped_before) + return new_messages, sorted(new_target_indices) + def transform_features( self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]] ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]: @@ -296,6 +396,39 @@ def _is_batched_messages(messages: Any) -> bool: return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list) +def _classify_message_for_dropout(message: dict[str, Any]) -> str | None: + """Best-effort classification of which recipe binding contributed + to this message, used for per-component dropout. + + The canonical recipe authors plan/memory/subtask injections with + distinctive prefix strings in the rendered content. Matching on + those prefixes is brittle if a future recipe author uses + different wording — but it's also localised to one place and + only affects the dropout fraction (never the actual semantics). + Returns ``None`` for messages we don't recognise; those are + always kept. + """ + content = message.get("content") + if isinstance(content, list): + text_parts: list[str] = [] + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + t = block.get("text") + if isinstance(t, str): + text_parts.append(t) + content = "\n".join(text_parts) + if not isinstance(content, str): + return None + head = content.lstrip().lower() + if head.startswith("plan:") or head.startswith("previous plan"): + return "plan" + if head.startswith("memory:") or head.startswith("previous memory"): + return "memory" + if head.startswith("current subtask") or head.startswith("completed subtask"): + return "subtask" + return None + + def _as_token_ids(value: Any) -> list[int]: if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()): value = value["input_ids"] diff --git a/src/lerobot/policies/smolvla2/configuration_smolvla2.py b/src/lerobot/policies/smolvla2/configuration_smolvla2.py index 6c3e1d898..99ce917e3 100644 --- a/src/lerobot/policies/smolvla2/configuration_smolvla2.py +++ b/src/lerobot/policies/smolvla2/configuration_smolvla2.py @@ -84,6 +84,24 @@ class SmolVLA2Config(SmolVLAConfig): effectively reduces SmolVLA2 back to SmolVLA's flow-only training, which is occasionally useful for ablations.""" + # Per-component prompt dropout (Pi0.7 §V.E) --------------------------- + # At training, randomly drop non-target context messages whose + # content was substituted from the named recipe binding. Forces + # the model to handle missing context — directly attacks the + # memorisation collapse where a stale or missing plan/memory at + # inference puts the prompt out-of-distribution and the LM head + # falls back to dominant-mode fragments. All default to 0.0 so + # behaviour is identical until explicitly enabled. + plan_dropout_prob: float = 0.0 + """Drop messages whose content starts with ``Plan:`` or ``Previous plan`` + with this probability per sample.""" + memory_dropout_prob: float = 0.0 + """Drop messages whose content starts with ``Memory:`` or ``Previous memory`` + with this probability per sample.""" + subtask_dropout_prob: float = 0.0 + """Drop messages whose content starts with ``Current subtask`` or + ``Completed subtask`` with this probability per sample.""" + def __post_init__(self) -> None: super().__post_init__() # Backbone needs gradients flowing through its text path when the diff --git a/src/lerobot/policies/smolvla2/processor_smolvla2.py b/src/lerobot/policies/smolvla2/processor_smolvla2.py index f844d08b5..93cbd0252 100644 --- a/src/lerobot/policies/smolvla2/processor_smolvla2.py +++ b/src/lerobot/policies/smolvla2/processor_smolvla2.py @@ -83,6 +83,9 @@ def make_smolvla2_pre_post_processors( tokenizer_name=config.vlm_model_name, max_length=config.tokenizer_max_length, padding=config.pad_language_to, + plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0), + memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0), + subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0), ), DeviceProcessorStep(device=config.device), NormalizerProcessorStep(