diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index bd9ac8e7b..45e0cf23d 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -307,6 +307,71 @@ def _build_observation_provider( return _provider +def _bootstrap_state_from_dataset( + *, + dataset_repo_id: str, + episode: int, + start_frame: int, +) -> dict[str, str]: + """Pull task / active plan / active memory / active subtask at ``start_frame``. + + The model is heavily memorised on the exact training prompts the + recipe rendered from this dataset (canonical task wording, + persistent atoms emitted earlier in the episode). Reconstructing + that state at REPL startup lets the runtime's first prompt line + up with what training looked like — without it the model sees an + out-of-distribution prompt and falls back to its dominant + training mode (VQA JSON spam). + """ + from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415 + + ds = LeRobotDataset(dataset_repo_id, episodes=[episode]) + if len(ds) == 0: + return {} + idx = max(0, min(start_frame, len(ds) - 1)) + sample = ds[idx] + + out: dict[str, str] = {} + task = sample.get("task") + if isinstance(task, str) and task.strip(): + out["task"] = task + + persistent = sample.get("language_persistent") or [] + # ``persistent`` is the broadcast slice of the episode; pick the + # *latest* row of each style whose ``timestamp`` is ≤ the + # frame's timestamp (matches the renderer's ``active_at`` + # semantics). + try: + frame_ts = ( + float(sample["timestamp"]) + if not hasattr(sample["timestamp"], "item") + else sample["timestamp"].item() + ) + except Exception: # noqa: BLE001 + frame_ts = float("inf") + + by_style: dict[str, tuple[float, str]] = {} + for row in persistent: + style = row.get("style") + ts = row.get("timestamp") + content = row.get("content") + if not (style and content) or ts is None: + continue + try: + ts_f = float(ts) + except (TypeError, ValueError): + continue + if ts_f > frame_ts: + continue + prev = by_style.get(style) + if prev is None or ts_f >= prev[0]: + by_style[style] = (ts_f, content) + for style, (_, content) in by_style.items(): + if style in {"plan", "memory", "subtask"}: + out[style] = content + return out + + def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]: """Instantiate the tools declared on this dataset/policy.""" if no_tts: @@ -364,6 +429,7 @@ def main(argv: list[str] | None = None) -> int: ) observation_provider: Callable[[], dict | None] | None = None + bootstrap_state: dict[str, str] = {} if args.dataset_repo_id is not None: print( f"[smolvla2] streaming observations from {args.dataset_repo_id} " @@ -379,6 +445,25 @@ def main(argv: list[str] | None = None) -> int: preprocessor=preprocessor, device=str(getattr(policy.config, "device", "cpu")), ) + # Pull the dataset's canonical task + the persistent atoms in + # force at the chosen start frame. The model is heavily + # memorised on the *exact* training prompts (task wording, + # current plan, current memory) — feeding ad-hoc user + # alternatives gives it nothing to recall against, so it + # collapses to its dominant training mode (VQA JSON). Reading + # the canonical state straight from the dataset gives the + # runtime a starting point that lines up with training. + bootstrap_state = _bootstrap_state_from_dataset( + dataset_repo_id=args.dataset_repo_id, + episode=args.dataset_episode, + start_frame=args.dataset_start_frame, + ) + if bootstrap_state.get("task") and not args.task: + args.task = bootstrap_state["task"] + print( + f"[smolvla2] using canonical task from dataset: {args.task!r}", + flush=True, + ) tools = _build_tools(args.no_tts, args.tts_voice) if tools: @@ -411,6 +496,17 @@ def main(argv: list[str] | None = None) -> int: ) if args.task: runtime.set_task(args.task) + # Bootstrap plan/memory from the dataset so the first prompt the + # runtime builds matches what training rendered (task + active + # plan + active memory). Without this the runtime starts with + # plan/memory empty, which only matched the very-early frames in + # training and is an out-of-distribution prompt for the rest. + if bootstrap_state.get("plan"): + runtime.state["current_plan"] = bootstrap_state["plan"] + if bootstrap_state.get("memory"): + runtime.state["current_memory"] = bootstrap_state["memory"] + if bootstrap_state.get("subtask"): + runtime.state["current_subtask"] = bootstrap_state["subtask"] return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)