mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
feat(smolvla2): per-component prompt dropout + augmented training script
Two complementary regularisers to attack the ``text_loss=6e-6 = memorised one dataset`` failure mode that's making the model collapse on real-robot input: 1. **Per-component prompt dropout** (Pi0.7 §V.E / plan's ``feat/pi05-prompt-dropout`` follow-up). ``SmolVLA2ChatTokenizerStep`` gains ``plan_dropout_prob`` / ``memory_dropout_prob`` / ``subtask_dropout_prob`` knobs (default 0.0 — opt-in). At training, non-target messages whose rendered content starts with ``Plan:`` / ``Memory:`` / ``Current subtask:`` etc. are dropped with their respective probability before tokenisation, with a deterministic per-sample RNG keyed off the dataset ``index``. ``target_message_indices`` is re-mapped so the supervision still lands on the right turn. Forces the model to handle missing plan/memory/subtask context — directly attacks the real-robot collapse where a stale or empty plan field puts the prompt OOD. Surfaced on ``SmolVLA2Config`` as three floats so they're ``--policy.<knob>=<value>``-controllable from the train CLI; plumbed through ``make_smolvla2_pre_post_processors``. 2. **Image augmentation** is already wired in lerobot via ``--dataset.image_transforms.enable=true`` (torchvision v2 ColorJitter + SharpnessJitter + RandomAffine, default 3 of 6 sampled per frame). No code change needed — just a CLI flag. ``examples/training/smolvla2_hirobot.slurm`` shows the full training command with both enabled. Drop-in replacement for the ad-hoc SLURM script Pepijn was using locally; same args, plus the three dropout probs and the image-transforms flag. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,84 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#SBATCH --job-name=smolvla2-hirobot
|
||||||
|
#SBATCH --partition=hopper-prod
|
||||||
|
#SBATCH --qos=high
|
||||||
|
#SBATCH --time=48:00:00
|
||||||
|
#SBATCH --ntasks=1
|
||||||
|
#SBATCH --gpus-per-task=8
|
||||||
|
|
||||||
|
# SmolVLA2 training on an annotated dataset, with image augmentation
|
||||||
|
# and per-component prompt dropout enabled — the two regularisers
|
||||||
|
# that move the model away from the "text_loss=6e-6 memorised one
|
||||||
|
# epoch worth of frames" failure mode toward "learns concepts, not
|
||||||
|
# pixels".
|
||||||
|
#
|
||||||
|
# What the regularisers do:
|
||||||
|
#
|
||||||
|
# * --dataset.image_transforms.enable=true: applies torchvision
|
||||||
|
# v2 ColorJitter (brightness/contrast/saturation/hue),
|
||||||
|
# SharpnessJitter and RandomAffine per frame at training time.
|
||||||
|
# Set max_num_transforms to control how many are sampled per
|
||||||
|
# frame; defaults to 3 of the 6.
|
||||||
|
# * --policy.plan_dropout_prob / memory / subtask: at training,
|
||||||
|
# randomly drop the context messages that carry the named
|
||||||
|
# binding so the model is forced to handle missing/stale context.
|
||||||
|
# Mirrors Pi0.7's prompt-component dropout (§V.E).
|
||||||
|
#
|
||||||
|
# Expected effect: text_loss plateaus higher (~0.5-2.0 instead of
|
||||||
|
# ~1e-5) and the model handles slight prompt/scene drift at
|
||||||
|
# inference instead of collapsing to memorised fragments.
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
||||||
|
|
||||||
|
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
||||||
|
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
||||||
|
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
|
||||||
|
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
|
||||||
|
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
|
||||||
|
|
||||||
|
DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
|
||||||
|
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/smolvla2_hirobot_super_poulain_tool4}"
|
||||||
|
JOB_NAME="${JOB_NAME:-smolvla2-hirobot-super-poulain-tool4}"
|
||||||
|
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
||||||
|
BATCH_SIZE="${BATCH_SIZE:-32}"
|
||||||
|
STEPS="${STEPS:-10000}"
|
||||||
|
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
|
||||||
|
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/smolvla2_hirobot_${RUN_ID}}"
|
||||||
|
|
||||||
|
echo "Training smolvla2 on $DATASET"
|
||||||
|
echo " GPUs: $NUM_PROCESSES"
|
||||||
|
echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
|
||||||
|
echo " steps: $STEPS"
|
||||||
|
echo " output: $OUTPUT_DIR"
|
||||||
|
echo " augmentation: image_transforms ON, prompt dropout {plan:0.15 memory:0.15 subtask:0.20}"
|
||||||
|
|
||||||
|
accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
|
||||||
|
-m lerobot.scripts.lerobot_train \
|
||||||
|
--policy.type=smolvla2 \
|
||||||
|
--policy.recipe_path=recipes/smolvla2_hirobot.yaml \
|
||||||
|
--dataset.repo_id="$DATASET" \
|
||||||
|
--dataset.revision=main \
|
||||||
|
--dataset.video_backend=pyav \
|
||||||
|
--dataset.image_transforms.enable=true \
|
||||||
|
--dataset.image_transforms.max_num_transforms=3 \
|
||||||
|
--dataset.image_transforms.random_order=true \
|
||||||
|
--policy.plan_dropout_prob=0.15 \
|
||||||
|
--policy.memory_dropout_prob=0.15 \
|
||||||
|
--policy.subtask_dropout_prob=0.20 \
|
||||||
|
--output_dir="$OUTPUT_DIR" \
|
||||||
|
--job_name="$JOB_NAME" \
|
||||||
|
--policy.repo_id="$POLICY_REPO_ID" \
|
||||||
|
--policy.compile_model=false \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.tokenizer_max_length=512 \
|
||||||
|
--steps="$STEPS" \
|
||||||
|
--policy.scheduler_decay_steps="$STEPS" \
|
||||||
|
--batch_size="$BATCH_SIZE" \
|
||||||
|
--wandb.enable=true \
|
||||||
|
--wandb.disable_artifact=true \
|
||||||
|
--wandb.project=hirobot \
|
||||||
|
--log_freq=100 \
|
||||||
|
--save_freq=1000 \
|
||||||
|
--num_workers=0
|
||||||
@@ -70,6 +70,22 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
|||||||
padding: str = "longest"
|
padding: str = "longest"
|
||||||
padding_side: str = "right"
|
padding_side: str = "right"
|
||||||
tools: list[dict[str, Any]] | None = None
|
tools: list[dict[str, Any]] | None = None
|
||||||
|
# --- Per-component prompt dropout (Pi0.7 §V.E, plan follow-up
|
||||||
|
# ``feat/pi05-prompt-dropout``). At training, drop non-target
|
||||||
|
# messages whose content was substituted from the named recipe
|
||||||
|
# binding with the given probability. Forces the model to handle
|
||||||
|
# missing context at inference — directly attacks the memorisation
|
||||||
|
# collapse where ``current_subtask=""`` puts the prompt OOD. All
|
||||||
|
# default to 0.0 (no dropout) so behaviour is identical until
|
||||||
|
# explicitly opted in via the training config.
|
||||||
|
plan_dropout_prob: float = 0.0
|
||||||
|
memory_dropout_prob: float = 0.0
|
||||||
|
subtask_dropout_prob: float = 0.0
|
||||||
|
interjection_dropout_prob: float = 0.0
|
||||||
|
# Optional seed for the per-sample RNG. ``None`` ⇒ use
|
||||||
|
# ``sample_idx`` derived from the transition (when present), so
|
||||||
|
# dropout is reproducible across runs but varies per sample.
|
||||||
|
dropout_seed: int | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Lazy: don't load the tokenizer until the step actually runs,
|
# Lazy: don't load the tokenizer until the step actually runs,
|
||||||
@@ -101,19 +117,38 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
|||||||
|
|
||||||
tokenizer = self._get_tokenizer()
|
tokenizer = self._get_tokenizer()
|
||||||
|
|
||||||
|
# Pull a sample_idx for the dropout RNG. ``index`` is the
|
||||||
|
# canonical per-frame key on ``LeRobotDataset`` samples and
|
||||||
|
# flows through into ``COMPLEMENTARY_DATA`` unchanged. When
|
||||||
|
# absent (e.g. inference) we fall back to 0 which is harmless
|
||||||
|
# because the dropout probs are also 0 at inference time.
|
||||||
|
sample_idx_raw = comp.get("index")
|
||||||
|
if hasattr(sample_idx_raw, "item"):
|
||||||
|
try:
|
||||||
|
sample_idx_raw = sample_idx_raw.item()
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
pass
|
||||||
|
|
||||||
if _is_batched_messages(messages):
|
if _is_batched_messages(messages):
|
||||||
|
indices_iter = (
|
||||||
|
sample_idx_raw
|
||||||
|
if isinstance(sample_idx_raw, (list, tuple))
|
||||||
|
else [sample_idx_raw] * len(messages)
|
||||||
|
)
|
||||||
encoded = [
|
encoded = [
|
||||||
self._encode_messages(
|
self._encode_messages(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
msg,
|
msg,
|
||||||
list(streams),
|
list(streams),
|
||||||
sorted(int(i) for i in indices),
|
sorted(int(i) for i in tgt_indices),
|
||||||
|
sample_idx=int(s_idx) if s_idx is not None else None,
|
||||||
)
|
)
|
||||||
for msg, streams, indices in zip(
|
for msg, streams, tgt_indices, s_idx in zip(
|
||||||
messages,
|
messages,
|
||||||
comp.get("message_streams") or [[] for _ in messages],
|
comp.get("message_streams") or [[] for _ in messages],
|
||||||
comp.get("target_message_indices") or [[] for _ in messages],
|
comp.get("target_message_indices") or [[] for _ in messages],
|
||||||
strict=True,
|
indices_iter,
|
||||||
|
strict=False,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
@@ -123,6 +158,7 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
|||||||
messages,
|
messages,
|
||||||
list(comp.get("message_streams") or []),
|
list(comp.get("message_streams") or []),
|
||||||
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
|
sorted(int(i) for i in (comp.get("target_message_indices") or [])),
|
||||||
|
sample_idx=int(sample_idx_raw) if sample_idx_raw is not None else None,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -190,7 +226,15 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
|||||||
messages: list[dict[str, Any]],
|
messages: list[dict[str, Any]],
|
||||||
message_streams: list[str | None],
|
message_streams: list[str | None],
|
||||||
target_indices: list[int],
|
target_indices: list[int],
|
||||||
|
sample_idx: int | None = None,
|
||||||
) -> tuple[list[int], list[int], bool]:
|
) -> tuple[list[int], list[int], bool]:
|
||||||
|
# Apply per-component prompt dropout *before* tokenisation, so
|
||||||
|
# the dropped messages don't contribute tokens or label-mask
|
||||||
|
# positions at all. Re-maps ``target_indices`` to account for
|
||||||
|
# removed messages.
|
||||||
|
messages, target_indices = self._apply_prompt_dropout(
|
||||||
|
messages, target_indices, sample_idx
|
||||||
|
)
|
||||||
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
text_messages = [_strip_lerobot_blocks(m) for m in messages]
|
||||||
|
|
||||||
full_ids = tokenizer.apply_chat_template(
|
full_ids = tokenizer.apply_chat_template(
|
||||||
@@ -231,6 +275,62 @@ class SmolVLA2ChatTokenizerStep(ProcessorStep):
|
|||||||
)
|
)
|
||||||
return [int(i) for i in full_ids], labels, predict_actions
|
return [int(i) for i in full_ids], labels, predict_actions
|
||||||
|
|
||||||
|
def _apply_prompt_dropout(
|
||||||
|
self,
|
||||||
|
messages: list[dict[str, Any]],
|
||||||
|
target_indices: list[int],
|
||||||
|
sample_idx: int | None,
|
||||||
|
) -> tuple[list[dict[str, Any]], list[int]]:
|
||||||
|
"""Probabilistically drop non-target context messages.
|
||||||
|
|
||||||
|
Heuristic content sniffing — matches the prefix strings that
|
||||||
|
``smolvla2_hirobot.yaml``'s recipes use when injecting plan /
|
||||||
|
memory / subtask / interjection content. Anything else is
|
||||||
|
kept unchanged. Target messages are never dropped (we still
|
||||||
|
need their tokens for supervision).
|
||||||
|
|
||||||
|
Returns ``(new_messages, new_target_indices)`` where the
|
||||||
|
indices are re-mapped to point at the same target turns in
|
||||||
|
the trimmed list.
|
||||||
|
"""
|
||||||
|
probs = {
|
||||||
|
"plan": float(self.plan_dropout_prob or 0.0),
|
||||||
|
"memory": float(self.memory_dropout_prob or 0.0),
|
||||||
|
"subtask": float(self.subtask_dropout_prob or 0.0),
|
||||||
|
"interjection": float(self.interjection_dropout_prob or 0.0),
|
||||||
|
}
|
||||||
|
if not any(p > 0.0 for p in probs.values()):
|
||||||
|
return messages, target_indices
|
||||||
|
|
||||||
|
# Deterministic per-sample RNG so dropout is reproducible
|
||||||
|
# across runs (matters for debugging / repro) but varies
|
||||||
|
# frame-to-frame.
|
||||||
|
import random # noqa: PLC0415
|
||||||
|
|
||||||
|
seed_int = self.dropout_seed if self.dropout_seed is not None else (sample_idx or 0)
|
||||||
|
rng = random.Random(int(seed_int) & 0xFFFFFFFF)
|
||||||
|
|
||||||
|
target_set = set(target_indices)
|
||||||
|
keep_flags: list[bool] = []
|
||||||
|
for i, msg in enumerate(messages):
|
||||||
|
if i in target_set:
|
||||||
|
keep_flags.append(True)
|
||||||
|
continue
|
||||||
|
kind = _classify_message_for_dropout(msg)
|
||||||
|
if kind and rng.random() < probs.get(kind, 0.0):
|
||||||
|
keep_flags.append(False)
|
||||||
|
else:
|
||||||
|
keep_flags.append(True)
|
||||||
|
|
||||||
|
new_messages = [m for m, keep in zip(messages, keep_flags) if keep]
|
||||||
|
# Re-map target_indices: each old index drops by the count of
|
||||||
|
# falsy flags before it.
|
||||||
|
new_target_indices: list[int] = []
|
||||||
|
for old_idx in target_indices:
|
||||||
|
dropped_before = sum(1 for k in keep_flags[:old_idx] if not k)
|
||||||
|
new_target_indices.append(old_idx - dropped_before)
|
||||||
|
return new_messages, sorted(new_target_indices)
|
||||||
|
|
||||||
def transform_features(
|
def transform_features(
|
||||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||||
@@ -296,6 +396,39 @@ def _is_batched_messages(messages: Any) -> bool:
|
|||||||
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
return isinstance(messages, list) and bool(messages) and isinstance(messages[0], list)
|
||||||
|
|
||||||
|
|
||||||
|
def _classify_message_for_dropout(message: dict[str, Any]) -> str | None:
|
||||||
|
"""Best-effort classification of which recipe binding contributed
|
||||||
|
to this message, used for per-component dropout.
|
||||||
|
|
||||||
|
The canonical recipe authors plan/memory/subtask injections with
|
||||||
|
distinctive prefix strings in the rendered content. Matching on
|
||||||
|
those prefixes is brittle if a future recipe author uses
|
||||||
|
different wording — but it's also localised to one place and
|
||||||
|
only affects the dropout fraction (never the actual semantics).
|
||||||
|
Returns ``None`` for messages we don't recognise; those are
|
||||||
|
always kept.
|
||||||
|
"""
|
||||||
|
content = message.get("content")
|
||||||
|
if isinstance(content, list):
|
||||||
|
text_parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "text":
|
||||||
|
t = block.get("text")
|
||||||
|
if isinstance(t, str):
|
||||||
|
text_parts.append(t)
|
||||||
|
content = "\n".join(text_parts)
|
||||||
|
if not isinstance(content, str):
|
||||||
|
return None
|
||||||
|
head = content.lstrip().lower()
|
||||||
|
if head.startswith("plan:") or head.startswith("previous plan"):
|
||||||
|
return "plan"
|
||||||
|
if head.startswith("memory:") or head.startswith("previous memory"):
|
||||||
|
return "memory"
|
||||||
|
if head.startswith("current subtask") or head.startswith("completed subtask"):
|
||||||
|
return "subtask"
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _as_token_ids(value: Any) -> list[int]:
|
def _as_token_ids(value: Any) -> list[int]:
|
||||||
if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()):
|
if isinstance(value, dict) or (hasattr(value, "keys") and "input_ids" in value.keys()):
|
||||||
value = value["input_ids"]
|
value = value["input_ids"]
|
||||||
|
|||||||
@@ -84,6 +84,24 @@ class SmolVLA2Config(SmolVLAConfig):
|
|||||||
effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
|
effectively reduces SmolVLA2 back to SmolVLA's flow-only training,
|
||||||
which is occasionally useful for ablations."""
|
which is occasionally useful for ablations."""
|
||||||
|
|
||||||
|
# Per-component prompt dropout (Pi0.7 §V.E) ---------------------------
|
||||||
|
# At training, randomly drop non-target context messages whose
|
||||||
|
# content was substituted from the named recipe binding. Forces
|
||||||
|
# the model to handle missing context — directly attacks the
|
||||||
|
# memorisation collapse where a stale or missing plan/memory at
|
||||||
|
# inference puts the prompt out-of-distribution and the LM head
|
||||||
|
# falls back to dominant-mode fragments. All default to 0.0 so
|
||||||
|
# behaviour is identical until explicitly enabled.
|
||||||
|
plan_dropout_prob: float = 0.0
|
||||||
|
"""Drop messages whose content starts with ``Plan:`` or ``Previous plan``
|
||||||
|
with this probability per sample."""
|
||||||
|
memory_dropout_prob: float = 0.0
|
||||||
|
"""Drop messages whose content starts with ``Memory:`` or ``Previous memory``
|
||||||
|
with this probability per sample."""
|
||||||
|
subtask_dropout_prob: float = 0.0
|
||||||
|
"""Drop messages whose content starts with ``Current subtask`` or
|
||||||
|
``Completed subtask`` with this probability per sample."""
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
# Backbone needs gradients flowing through its text path when the
|
# Backbone needs gradients flowing through its text path when the
|
||||||
|
|||||||
@@ -83,6 +83,9 @@ def make_smolvla2_pre_post_processors(
|
|||||||
tokenizer_name=config.vlm_model_name,
|
tokenizer_name=config.vlm_model_name,
|
||||||
max_length=config.tokenizer_max_length,
|
max_length=config.tokenizer_max_length,
|
||||||
padding=config.pad_language_to,
|
padding=config.pad_language_to,
|
||||||
|
plan_dropout_prob=getattr(config, "plan_dropout_prob", 0.0),
|
||||||
|
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||||
|
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
NormalizerProcessorStep(
|
NormalizerProcessorStep(
|
||||||
|
|||||||
Reference in New Issue
Block a user