feat(smolvla2-runtime): --dataset.augment_at_inference for the bisection test

Apply the training-time torchvision-v2 ColorJitter / SharpnessJitter /
RandomAffine pipeline to dataset frames in dry-run, so we can isolate
whether the LM head's collapse to '\n' on live frames is:

  * pure scene-content OOD (unaugmented dataset frames work, mildly
    augmented ones still work — model has learned the augmentation
    distribution, only fails when the scene content itself diverges)
  * hyper-specific memorisation (dry-run with augmentation also
    collapses to '\n' — head is nailed to the exact unperturbed
    training samples and only the retrain helps)

Usage:

  lerobot-smolvla2-runtime --no_robot --policy.path=... \
    --dataset.repo_id=... --dataset.episode=0 \
    --dataset.start_frame=1000 \
    --dataset.augment_at_inference

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 18:14:57 +02:00
parent 0410705aff
commit 4852b9f952
@@ -119,6 +119,21 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
"frame by frame; set to 0 to freeze on ``start_frame``."
),
)
p.add_argument(
"--dataset.augment_at_inference",
dest="dataset_augment_at_inference",
action="store_true",
help=(
"Apply the same torchvision-v2 ColorJitter / SharpnessJitter "
"/ RandomAffine pipeline that training used to each dataset "
"frame fed to the policy. Use to test whether the LM head "
"generalises under the augmentation distribution it was "
"supervised on — if dry-run still produces coherent subtask "
"text with this flag on, the head has learned beyond exact "
"frames; if it collapses to '\\n' the head is hyper-specific "
"to the unperturbed training samples."
),
)
p.add_argument(
"--task",
dest="task",
@@ -301,6 +316,7 @@ def _build_observation_provider(
advance_per_tick: int,
preprocessor: Any,
device: str,
augment: bool = False,
) -> Callable[[], dict | None]:
"""Build a closure that feeds dataset frames into the runtime.
@@ -326,6 +342,31 @@ def _build_observation_provider(
f"Dataset {dataset_repo_id!r} episode {episode} is empty."
)
# Optional: apply the same torchvision-v2 augmentation pipeline
# that training used, so dry-run sees frames from the augmented
# support region (not just the unperturbed dataset frames). When
# the LM head still generates coherent text under this, it has
# learned over the augmentation distribution — the *opposite* of
# the "memorised one specific frame per supervision" failure
# mode. When it collapses to ``\n`` here too, the head is hyper-
# specific to the unperturbed training samples and only the
# retrain can help.
inference_aug = None
if augment:
from lerobot.transforms import ( # noqa: PLC0415
ImageTransforms,
ImageTransformsConfig,
)
aug_cfg = ImageTransformsConfig(enable=True)
inference_aug = ImageTransforms(aug_cfg)
ds.set_image_transforms(inference_aug)
logger.warning(
"dry-run augmentation ENABLED — frames will be jittered "
"(brightness/contrast/saturation/hue/sharpness/affine) "
"before going to the policy"
)
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
def _provider() -> dict | None:
@@ -1185,6 +1226,7 @@ def main(argv: list[str] | None = None) -> int:
advance_per_tick=args.dataset_advance_per_tick,
preprocessor=preprocessor,
device=str(getattr(policy.config, "device", "cpu")),
augment=getattr(args, "dataset_augment_at_inference", False),
)
tools = _build_tools(args.no_tts, args.tts_voice)