mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
fix(smolvla2): bootstrap canonical task + plan/memory from dataset
The user-typed task and the dataset's canonical task differ in
wording (capitalisation, ``green box`` vs ``green bin``, etc.). With
``text_loss`` driven down to ~6e-6 across 78 epochs, the model has
memorised the *exact* rendered training prompts: any wording drift
puts the prompt out of distribution and the model collapses to its
dominant training mode (VQA JSON output).
When ``--dataset.repo_id`` is set, automatically:
* read the canonical task string from the chosen episode (and use
it as ``--task`` when the user didn't pass one);
* pull the active ``plan`` / ``memory`` / ``subtask`` rows from the
persistent slice (latest row whose timestamp ≤ start frame's
timestamp — same semantics as the renderer's ``active_at``) and
seed them into the runtime state.
The first prompt the runtime builds at REPL start now mirrors what
the recipe rendered during training (task + active plan + active
memory + optional current subtask). The user can still override any
of these by typing.
Memorisation itself is an upstream problem (the training mix collapsed
to too few unique high-level targets); this commit only fixes the
inference-side prompt mismatch that was making the memorisation surface
as gibberish.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -307,6 +307,71 @@ def _build_observation_provider(
|
|||||||
return _provider
|
return _provider
|
||||||
|
|
||||||
|
|
||||||
|
def _bootstrap_state_from_dataset(
    *,
    dataset_repo_id: str,
    episode: int,
    start_frame: int,
) -> dict[str, str]:
    """Recover the task and the active plan / memory / subtask at ``start_frame``.

    The requested episode is loaded on its own, ``start_frame`` is clamped
    into the episode's valid index range, and the sample at that index is
    inspected. The intent is to let the runtime's first prompt at REPL
    startup line up with the prompts rendered from this dataset during
    training, rather than starting from an out-of-distribution prompt.

    Args:
        dataset_repo_id: Repository id handed to ``LeRobotDataset``.
        episode: Index of the single episode to load.
        start_frame: Frame index within the loaded episode; values outside
            the valid range are clamped to the first / last frame.

    Returns:
        A dict that may contain ``"task"`` (only when the sample's task is
        a non-blank string) and any of ``"plan"``, ``"memory"``,
        ``"subtask"``. Empty when the loaded dataset has no frames.
    """
    from lerobot.datasets.lerobot_dataset import LeRobotDataset  # noqa: PLC0415

    dataset = LeRobotDataset(dataset_repo_id, episodes=[episode])
    num_frames = len(dataset)
    if num_frames == 0:
        return {}

    # Clamp into [0, num_frames - 1] so an out-of-range start frame still
    # resolves to a real sample.
    frame_index = min(max(start_frame, 0), num_frames - 1)
    sample = dataset[frame_index]

    state: dict[str, str] = {}
    task = sample.get("task")
    if isinstance(task, str) and task.strip():
        state["task"] = task

    rows = sample.get("language_persistent") or []

    # Best effort: when the frame timestamp cannot be read, treat every row
    # as already in force by using +inf as the cut-off.
    try:
        raw_ts = sample["timestamp"]
        frame_ts = raw_ts.item() if hasattr(raw_ts, "item") else float(raw_ts)
    except Exception:  # noqa: BLE001
        frame_ts = float("inf")

    # For each style keep the row with the greatest timestamp that is not
    # after the frame's timestamp; on ties the later row in the list wins.
    latest: dict[str, tuple[float, str]] = {}
    for row in rows:
        style = row.get("style")
        stamp = row.get("timestamp")
        content = row.get("content")
        if not style or not content or stamp is None:
            continue
        try:
            stamp_f = float(stamp)
        except (TypeError, ValueError):
            continue
        if stamp_f > frame_ts:
            continue
        current = latest.get(style)
        if current is None or stamp_f >= current[0]:
            latest[style] = (stamp_f, content)

    # Only these three styles are exposed to the caller.
    wanted = {"plan", "memory", "subtask"}
    state.update(
        {style: content for style, (_, content) in latest.items() if style in wanted}
    )
    return state
|
||||||
|
|
||||||
|
|
||||||
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||||
"""Instantiate the tools declared on this dataset/policy."""
|
"""Instantiate the tools declared on this dataset/policy."""
|
||||||
if no_tts:
|
if no_tts:
|
||||||
@@ -364,6 +429,7 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
)
|
)
|
||||||
|
|
||||||
observation_provider: Callable[[], dict | None] | None = None
|
observation_provider: Callable[[], dict | None] | None = None
|
||||||
|
bootstrap_state: dict[str, str] = {}
|
||||||
if args.dataset_repo_id is not None:
|
if args.dataset_repo_id is not None:
|
||||||
print(
|
print(
|
||||||
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
||||||
@@ -379,6 +445,25 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
device=str(getattr(policy.config, "device", "cpu")),
|
device=str(getattr(policy.config, "device", "cpu")),
|
||||||
)
|
)
|
||||||
|
# Pull the dataset's canonical task + the persistent atoms in
|
||||||
|
# force at the chosen start frame. The model is heavily
|
||||||
|
# memorised on the *exact* training prompts (task wording,
|
||||||
|
# current plan, current memory) — feeding ad-hoc user
|
||||||
|
# alternatives gives it nothing to recall against, so it
|
||||||
|
# collapses to its dominant training mode (VQA JSON). Reading
|
||||||
|
# the canonical state straight from the dataset gives the
|
||||||
|
# runtime a starting point that lines up with training.
|
||||||
|
bootstrap_state = _bootstrap_state_from_dataset(
|
||||||
|
dataset_repo_id=args.dataset_repo_id,
|
||||||
|
episode=args.dataset_episode,
|
||||||
|
start_frame=args.dataset_start_frame,
|
||||||
|
)
|
||||||
|
if bootstrap_state.get("task") and not args.task:
|
||||||
|
args.task = bootstrap_state["task"]
|
||||||
|
print(
|
||||||
|
f"[smolvla2] using canonical task from dataset: {args.task!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
tools = _build_tools(args.no_tts, args.tts_voice)
|
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||||
if tools:
|
if tools:
|
||||||
@@ -411,6 +496,17 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
)
|
)
|
||||||
if args.task:
|
if args.task:
|
||||||
runtime.set_task(args.task)
|
runtime.set_task(args.task)
|
||||||
|
# Bootstrap plan/memory from the dataset so the first prompt the
|
||||||
|
# runtime builds matches what training rendered (task + active
|
||||||
|
# plan + active memory). Without this the runtime starts with
|
||||||
|
# plan/memory empty, which only matched the very-early frames in
|
||||||
|
# training and is an out-of-distribution prompt for the rest.
|
||||||
|
if bootstrap_state.get("plan"):
|
||||||
|
runtime.state["current_plan"] = bootstrap_state["plan"]
|
||||||
|
if bootstrap_state.get("memory"):
|
||||||
|
runtime.state["current_memory"] = bootstrap_state["memory"]
|
||||||
|
if bootstrap_state.get("subtask"):
|
||||||
|
runtime.state["current_subtask"] = bootstrap_state["subtask"]
|
||||||
|
|
||||||
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user