mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-02 23:57:24 +00:00
feat(smolvla2-runtime): --dataset.augment_at_inference for the bisection test
Apply the training-time torchvision-v2 ColorJitter / SharpnessJitter /
RandomAffine pipeline to dataset frames in dry-run, so we can isolate
whether the LM head's collapse to '\n' on live frames is:
* pure scene-content OOD (unaugmented dataset frames work, mildly
augmented ones still work — model has learned the augmentation
distribution, only fails when the scene content itself diverges)
* hyper-specific memorisation (dry-run with augmentation also
collapses to '\n' — head is nailed to the exact unperturbed
training samples and only the retrain helps)
Usage:
lerobot-smolvla2-runtime --no_robot --policy.path=... \
--dataset.repo_id=... --dataset.episode=0 \
--dataset.start_frame=1000 \
--dataset.augment_at_inference
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -119,6 +119,21 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
||||
"frame by frame; set to 0 to freeze on ``start_frame``."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--dataset.augment_at_inference",
|
||||
dest="dataset_augment_at_inference",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Apply the same torchvision-v2 ColorJitter / SharpnessJitter "
|
||||
"/ RandomAffine pipeline that training used to each dataset "
|
||||
"frame fed to the policy. Use to test whether the LM head "
|
||||
"generalises under the augmentation distribution it was "
|
||||
"supervised on — if dry-run still produces coherent subtask "
|
||||
"text with this flag on, the head has learned beyond exact "
|
||||
"frames; if it collapses to '\\n' the head is hyper-specific "
|
||||
"to the unperturbed training samples."
|
||||
),
|
||||
)
|
||||
p.add_argument(
|
||||
"--task",
|
||||
dest="task",
|
||||
@@ -301,6 +316,7 @@ def _build_observation_provider(
|
||||
advance_per_tick: int,
|
||||
preprocessor: Any,
|
||||
device: str,
|
||||
augment: bool = False,
|
||||
) -> Callable[[], dict | None]:
|
||||
"""Build a closure that feeds dataset frames into the runtime.
|
||||
|
||||
@@ -326,6 +342,31 @@ def _build_observation_provider(
|
||||
f"Dataset {dataset_repo_id!r} episode {episode} is empty."
|
||||
)
|
||||
|
||||
# Optional: apply the same torchvision-v2 augmentation pipeline
|
||||
# that training used, so dry-run sees frames from the augmented
|
||||
# support region (not just the unperturbed dataset frames). When
|
||||
# the LM head still generates coherent text under this, it has
|
||||
# learned over the augmentation distribution — the *opposite* of
|
||||
# the "memorised one specific frame per supervision" failure
|
||||
# mode. When it collapses to ``\n`` here too, the head is hyper-
|
||||
# specific to the unperturbed training samples and only the
|
||||
# retrain can help.
|
||||
inference_aug = None
|
||||
if augment:
|
||||
from lerobot.transforms import ( # noqa: PLC0415
|
||||
ImageTransforms,
|
||||
ImageTransformsConfig,
|
||||
)
|
||||
|
||||
aug_cfg = ImageTransformsConfig(enable=True)
|
||||
inference_aug = ImageTransforms(aug_cfg)
|
||||
ds.set_image_transforms(inference_aug)
|
||||
logger.warning(
|
||||
"dry-run augmentation ENABLED — frames will be jittered "
|
||||
"(brightness/contrast/saturation/hue/sharpness/affine) "
|
||||
"before going to the policy"
|
||||
)
|
||||
|
||||
state = {"cursor": max(0, min(start_frame, len(ds) - 1))}
|
||||
|
||||
def _provider() -> dict | None:
|
||||
@@ -1185,6 +1226,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
advance_per_tick=args.dataset_advance_per_tick,
|
||||
preprocessor=preprocessor,
|
||||
device=str(getattr(policy.config, "device", "cpu")),
|
||||
augment=getattr(args, "dataset_augment_at_inference", False),
|
||||
)
|
||||
|
||||
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||
|
||||
Reference in New Issue
Block a user