fix(smolvla2): bootstrap canonical task + plan/memory from dataset

The user-typed task and the dataset's canonical task differ in
wording (capitalisation, ``green box`` vs ``green bin``, etc.). With
``text_loss`` driven down to ~6e-6 across 78 epochs the model is
memorised on the *exact* rendered training prompts: any wording drift
puts the prompt out of distribution and the model collapses to its
dominant training mode (VQA JSON output).

When ``--dataset.repo_id`` is set, automatically:
  * read the canonical task string from the chosen episode (and use
    it as ``--task`` when the user didn't pass one);
  * pull the active ``plan`` / ``memory`` / ``subtask`` rows from the
    persistent slice (latest row whose timestamp ≤ start frame's
    timestamp — same semantics as the renderer's ``active_at``) and
    seed them into the runtime state.

The first prompt the runtime builds at REPL start now mirrors what
the recipe rendered during training (task + active plan + active
memory + optional current subtask). The user can still override any
of these by typing.

Memorisation itself is upstream (training mix collapsed to too few
unique high-level targets); this commit only fixes the inference-side
prompt mismatch that was making the memorisation surface as gibberish.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-05 14:00:36 +02:00
parent a47e535b02
commit 7a945d7bdc
@@ -307,6 +307,71 @@ def _build_observation_provider(
return _provider
def _bootstrap_state_from_dataset(
*,
dataset_repo_id: str,
episode: int,
start_frame: int,
) -> dict[str, str]:
"""Pull task / active plan / active memory / active subtask at ``start_frame``.
The model is heavily memorised on the exact training prompts the
recipe rendered from this dataset (canonical task wording,
persistent atoms emitted earlier in the episode). Reconstructing
that state at REPL startup lets the runtime's first prompt line
up with what training looked like — without it the model sees an
out-of-distribution prompt and falls back to its dominant
training mode (VQA JSON spam).
"""
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
ds = LeRobotDataset(dataset_repo_id, episodes=[episode])
if len(ds) == 0:
return {}
idx = max(0, min(start_frame, len(ds) - 1))
sample = ds[idx]
out: dict[str, str] = {}
task = sample.get("task")
if isinstance(task, str) and task.strip():
out["task"] = task
persistent = sample.get("language_persistent") or []
# ``persistent`` is the broadcast slice of the episode; pick the
# *latest* row of each style whose ``timestamp`` is ≤ the
# frame's timestamp (matches the renderer's ``active_at``
# semantics).
try:
frame_ts = (
float(sample["timestamp"])
if not hasattr(sample["timestamp"], "item")
else sample["timestamp"].item()
)
except Exception: # noqa: BLE001
frame_ts = float("inf")
by_style: dict[str, tuple[float, str]] = {}
for row in persistent:
style = row.get("style")
ts = row.get("timestamp")
content = row.get("content")
if not (style and content) or ts is None:
continue
try:
ts_f = float(ts)
except (TypeError, ValueError):
continue
if ts_f > frame_ts:
continue
prev = by_style.get(style)
if prev is None or ts_f >= prev[0]:
by_style[style] = (ts_f, content)
for style, (_, content) in by_style.items():
if style in {"plan", "memory", "subtask"}:
out[style] = content
return out
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
"""Instantiate the tools declared on this dataset/policy."""
if no_tts:
@@ -364,6 +429,7 @@ def main(argv: list[str] | None = None) -> int:
)
observation_provider: Callable[[], dict | None] | None = None
bootstrap_state: dict[str, str] = {}
if args.dataset_repo_id is not None:
print(
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
@@ -379,6 +445,25 @@ def main(argv: list[str] | None = None) -> int:
preprocessor=preprocessor,
device=str(getattr(policy.config, "device", "cpu")),
)
# Pull the dataset's canonical task + the persistent atoms in
# force at the chosen start frame. The model is heavily
# memorised on the *exact* training prompts (task wording,
# current plan, current memory) — feeding ad-hoc user
# alternatives gives it nothing to recall against, so it
# collapses to its dominant training mode (VQA JSON). Reading
# the canonical state straight from the dataset gives the
# runtime a starting point that lines up with training.
bootstrap_state = _bootstrap_state_from_dataset(
dataset_repo_id=args.dataset_repo_id,
episode=args.dataset_episode,
start_frame=args.dataset_start_frame,
)
if bootstrap_state.get("task") and not args.task:
args.task = bootstrap_state["task"]
print(
f"[smolvla2] using canonical task from dataset: {args.task!r}",
flush=True,
)
tools = _build_tools(args.no_tts, args.tts_voice)
if tools:
@@ -411,6 +496,17 @@ def main(argv: list[str] | None = None) -> int:
)
if args.task:
runtime.set_task(args.task)
# Bootstrap plan/memory from the dataset so the first prompt the
# runtime builds matches what training rendered (task + active
# plan + active memory). Without this the runtime starts with
# plan/memory empty, which only matched the very-early frames in
# training and is an out-of-distribution prompt for the rest.
if bootstrap_state.get("plan"):
runtime.state["current_plan"] = bootstrap_state["plan"]
if bootstrap_state.get("memory"):
runtime.state["current_memory"] = bootstrap_state["memory"]
if bootstrap_state.get("subtask"):
runtime.state["current_subtask"] = bootstrap_state["subtask"]
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)