diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index 58d9cd0dd..f0c390f8f 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -119,6 +119,21 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: "frame by frame; set to 0 to freeze on ``start_frame``." ), ) + p.add_argument( + "--dataset.augment_at_inference", + dest="dataset_augment_at_inference", + action="store_true", + help=( + "Apply the same torchvision-v2 ColorJitter / SharpnessJitter " + "/ RandomAffine pipeline that training used to each dataset " + "frame fed to the policy. Use to test whether the LM head " + "generalises under the augmentation distribution it was " + "supervised on — if dry-run still produces coherent subtask " + "text with this flag on, the head has learned beyond exact " + "frames; if it collapses to '\\n' the head is hyper-specific " + "to the unperturbed training samples." + ), + ) p.add_argument( "--task", dest="task", @@ -301,6 +316,7 @@ def _build_observation_provider( advance_per_tick: int, preprocessor: Any, device: str, + augment: bool = False, ) -> Callable[[], dict | None]: """Build a closure that feeds dataset frames into the runtime. @@ -326,6 +342,31 @@ def _build_observation_provider( f"Dataset {dataset_repo_id!r} episode {episode} is empty." ) + # Optional: apply the same torchvision-v2 augmentation pipeline + # that training used, so dry-run sees frames from the augmented + # support region (not just the unperturbed dataset frames). When + # the LM head still generates coherent text under this, it has + # learned over the augmentation distribution — the *opposite* of + # the "memorised one specific frame per supervision" failure + # mode. When it collapses to ``\n`` here too, the head is hyper- + # specific to the unperturbed training samples and only the + # retrain can help. + inference_aug = None + if augment: + from lerobot.transforms import ( # noqa: PLC0415 + ImageTransforms, + ImageTransformsConfig, + ) + + aug_cfg = ImageTransformsConfig(enable=True) + inference_aug = ImageTransforms(aug_cfg) + ds.set_image_transforms(inference_aug) + logger.warning( + "dry-run augmentation ENABLED — frames will be jittered " + "(brightness/contrast/saturation/hue/sharpness/affine) " + "before going to the policy" + ) + state = {"cursor": max(0, min(start_frame, len(ds) - 1))} def _provider() -> dict | None: @@ -1185,6 +1226,7 @@ def main(argv: list[str] | None = None) -> int: advance_per_tick=args.dataset_advance_per_tick, preprocessor=preprocessor, device=str(getattr(policy.config, "device", "cpu")), + augment=getattr(args, "dataset_augment_at_inference", False), ) tools = _build_tools(args.no_tts, args.tts_voice)