mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
fix(smolvla2): bootstrap canonical task + plan/memory from dataset
The user-typed task and the dataset's canonical task differ in
wording (capitalisation, ``green box`` vs ``green bin``, etc.). With
``text_loss`` driven down to ~6e-6 across 78 epochs the model is
memorised on the *exact* rendered training prompts: any wording drift
puts the prompt out of distribution and the model collapses to its
dominant training mode (VQA JSON output).
When ``--dataset.repo_id`` is set, automatically:
* read the canonical task string from the chosen episode (and use
it as ``--task`` when the user didn't pass one);
* pull the active ``plan`` / ``memory`` / ``subtask`` rows from the
persistent slice (latest row whose timestamp ≤ start frame's
timestamp — same semantics as the renderer's ``active_at``) and
seed them into the runtime state.
The first prompt the runtime builds at REPL start now mirrors what
the recipe rendered during training (task + active plan + active
memory + optional current subtask). The user can still override any
of these by typing.
Memorisation itself is upstream (training mix collapsed to too few
unique high-level targets); this commit only fixes the inference-side
prompt mismatch that was making the memorisation surface as gibberish.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -307,6 +307,71 @@ def _build_observation_provider(
|
||||
return _provider
|
||||
|
||||
|
||||
def _bootstrap_state_from_dataset(
|
||||
*,
|
||||
dataset_repo_id: str,
|
||||
episode: int,
|
||||
start_frame: int,
|
||||
) -> dict[str, str]:
|
||||
"""Pull task / active plan / active memory / active subtask at ``start_frame``.
|
||||
|
||||
The model is heavily memorised on the exact training prompts the
|
||||
recipe rendered from this dataset (canonical task wording,
|
||||
persistent atoms emitted earlier in the episode). Reconstructing
|
||||
that state at REPL startup lets the runtime's first prompt line
|
||||
up with what training looked like — without it the model sees an
|
||||
out-of-distribution prompt and falls back to its dominant
|
||||
training mode (VQA JSON spam).
|
||||
"""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: PLC0415
|
||||
|
||||
ds = LeRobotDataset(dataset_repo_id, episodes=[episode])
|
||||
if len(ds) == 0:
|
||||
return {}
|
||||
idx = max(0, min(start_frame, len(ds) - 1))
|
||||
sample = ds[idx]
|
||||
|
||||
out: dict[str, str] = {}
|
||||
task = sample.get("task")
|
||||
if isinstance(task, str) and task.strip():
|
||||
out["task"] = task
|
||||
|
||||
persistent = sample.get("language_persistent") or []
|
||||
# ``persistent`` is the broadcast slice of the episode; pick the
|
||||
# *latest* row of each style whose ``timestamp`` is ≤ the
|
||||
# frame's timestamp (matches the renderer's ``active_at``
|
||||
# semantics).
|
||||
try:
|
||||
frame_ts = (
|
||||
float(sample["timestamp"])
|
||||
if not hasattr(sample["timestamp"], "item")
|
||||
else sample["timestamp"].item()
|
||||
)
|
||||
except Exception: # noqa: BLE001
|
||||
frame_ts = float("inf")
|
||||
|
||||
by_style: dict[str, tuple[float, str]] = {}
|
||||
for row in persistent:
|
||||
style = row.get("style")
|
||||
ts = row.get("timestamp")
|
||||
content = row.get("content")
|
||||
if not (style and content) or ts is None:
|
||||
continue
|
||||
try:
|
||||
ts_f = float(ts)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if ts_f > frame_ts:
|
||||
continue
|
||||
prev = by_style.get(style)
|
||||
if prev is None or ts_f >= prev[0]:
|
||||
by_style[style] = (ts_f, content)
|
||||
for style, (_, content) in by_style.items():
|
||||
if style in {"plan", "memory", "subtask"}:
|
||||
out[style] = content
|
||||
return out
|
||||
|
||||
|
||||
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||
"""Instantiate the tools declared on this dataset/policy."""
|
||||
if no_tts:
|
||||
@@ -364,6 +429,7 @@ def main(argv: list[str] | None = None) -> int:
|
||||
)
|
||||
|
||||
observation_provider: Callable[[], dict | None] | None = None
|
||||
bootstrap_state: dict[str, str] = {}
|
||||
if args.dataset_repo_id is not None:
|
||||
print(
|
||||
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
||||
@@ -379,6 +445,25 @@ def main(argv: list[str] | None = None) -> int:
|
||||
preprocessor=preprocessor,
|
||||
device=str(getattr(policy.config, "device", "cpu")),
|
||||
)
|
||||
# Pull the dataset's canonical task + the persistent atoms in
|
||||
# force at the chosen start frame. The model is heavily
|
||||
# memorised on the *exact* training prompts (task wording,
|
||||
# current plan, current memory) — feeding ad-hoc user
|
||||
# alternatives gives it nothing to recall against, so it
|
||||
# collapses to its dominant training mode (VQA JSON). Reading
|
||||
# the canonical state straight from the dataset gives the
|
||||
# runtime a starting point that lines up with training.
|
||||
bootstrap_state = _bootstrap_state_from_dataset(
|
||||
dataset_repo_id=args.dataset_repo_id,
|
||||
episode=args.dataset_episode,
|
||||
start_frame=args.dataset_start_frame,
|
||||
)
|
||||
if bootstrap_state.get("task") and not args.task:
|
||||
args.task = bootstrap_state["task"]
|
||||
print(
|
||||
f"[smolvla2] using canonical task from dataset: {args.task!r}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||
if tools:
|
||||
@@ -411,6 +496,17 @@ def main(argv: list[str] | None = None) -> int:
|
||||
)
|
||||
if args.task:
|
||||
runtime.set_task(args.task)
|
||||
# Bootstrap plan/memory from the dataset so the first prompt the
|
||||
# runtime builds matches what training rendered (task + active
|
||||
# plan + active memory). Without this the runtime starts with
|
||||
# plan/memory empty, which only matched the very-early frames in
|
||||
# training and is an out-of-distribution prompt for the rest.
|
||||
if bootstrap_state.get("plan"):
|
||||
runtime.state["current_plan"] = bootstrap_state["plan"]
|
||||
if bootstrap_state.get("memory"):
|
||||
runtime.state["current_memory"] = bootstrap_state["memory"]
|
||||
if bootstrap_state.get("subtask"):
|
||||
runtime.state["current_subtask"] = bootstrap_state["subtask"]
|
||||
|
||||
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user