mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
feat(smolvla2): autonomous robot mode in lerobot-smolvla2-runtime
The runtime CLI was deliberately scoped to dry-run only: it
hard-coded ``robot_executor=None`` and printed a "real-robot
integration is a follow-up" warning even when ``--no_robot`` was
omitted. The runtime *engine* was already structured for real-robot
operation (separate ``LowLevelForward`` chunk-rate generation +
``DispatchAction`` ctrl-rate dispatch with a ``robot_executor``
hook); only the wiring was missing.
Add the wiring:
* ``_load_policy_and_preprocessor`` now also returns the
postprocessor (action denormaliser).
* ``--robot.type`` / ``--robot.port`` / ``--robot.id`` /
``--robot.cameras`` (JSON) build a ``Robot`` via
``make_robot_from_config`` and connect it.
* ``_build_robot_observation_provider`` reads
``robot.get_observation()`` each call, drops the language
columns (runtime drives messages itself), and runs the policy's
preprocessor (rename → batch → device → normalise).
* ``_build_robot_action_executor`` postprocesses the policy's
action tensor (denormalise), converts to the ``{joint: value}``
dict via ``make_robot_action(action, ds_meta.features)``, and
calls ``robot.send_action(...)``. Optional ``--max_action_norm``
safety clip rejects ticks whose action L2 norm exceeds the
threshold (kill-switch when bringing up a new robot).
* ``_run_autonomous`` runs ``runtime.run()`` in a background
thread (the policy must keep generating chunks at chunk_hz and
dispatching at ctrl_hz regardless of stdin) and handles user
interjections / VQA queries from the foreground stdin loop.
Confirmation prompt before start (skip with ``--auto_start``);
Ctrl+C stops the thread and disconnects the robot cleanly.
* Autonomous mode requires ``--dataset.repo_id`` for action stats
/ feature shapes — pass the same dataset the policy was trained
on. The bootstrap path that pulls the canonical task / plan / memory
runs in both REPL and autonomous modes, so the model's first
prompt matches the training distribution.
Dry-run REPL behaviour is unchanged when ``--robot.type`` is not
passed.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -131,6 +131,75 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Skip robot connection — language-only / dry-run mode.",
|
help="Skip robot connection — language-only / dry-run mode.",
|
||||||
)
|
)
|
||||||
|
# --- Real-robot mode args ----------------------------------------
|
||||||
|
# Setting ``--robot.type`` flips the runtime into autonomous mode:
|
||||||
|
# it connects to the robot, builds an observation provider that
|
||||||
|
# reads ``robot.get_observation()`` instead of dataset frames, and
|
||||||
|
# an action executor that postprocesses (denormalises) the policy's
|
||||||
|
# output and calls ``robot.send_action(...)`` at ``--ctrl_hz``. The
|
||||||
|
# high-level REPL-style stdin still works in the foreground (main)
|
||||||
|
# thread for interjections / VQA while the control loop runs in the background.
|
||||||
|
p.add_argument(
|
||||||
|
"--robot.type",
|
||||||
|
dest="robot_type",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Robot config choice (e.g. ``so101``, ``so101_follower``). "
|
||||||
|
"When set, the runtime drives the actual robot at "
|
||||||
|
"``--ctrl_hz`` instead of running the dataset-driven dry-run "
|
||||||
|
"REPL. Implies ``--autonomous`` unless ``--no_robot`` is also "
|
||||||
|
"passed (in which case the flag is ignored). See "
|
||||||
|
"``lerobot.robots`` for available choices."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--robot.port",
|
||||||
|
dest="robot_port",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Serial port for the robot (e.g. ``/dev/tty.usbmodem...``).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--robot.id",
|
||||||
|
dest="robot_id",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Optional robot identifier (passed through to ``RobotConfig.id``).",
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--robot.cameras",
|
||||||
|
dest="robot_cameras",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Optional JSON dict describing camera configs to attach to "
|
||||||
|
"the robot (e.g. ``'{\"top\": {\"type\": \"opencv\", \"index\": 0}}'``). "
|
||||||
|
"Camera keys MUST match the ``observation.images.*`` features "
|
||||||
|
"the policy was trained on."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--max_action_norm",
|
||||||
|
dest="max_action_norm",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Safety clip: reject any individual action whose L2 norm "
|
||||||
|
"exceeds this value. Default ``None`` = no clipping. Useful "
|
||||||
|
"as a kill-switch when bringing up a new robot/task pair."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
p.add_argument(
|
||||||
|
"--auto_start",
|
||||||
|
action="store_true",
|
||||||
|
help=(
|
||||||
|
"Skip the ``Press ENTER to start`` confirmation prompt before "
|
||||||
|
"the autonomous control loop begins. Off by default — having "
|
||||||
|
"to confirm catches a lot of stupid mistakes (wrong policy, "
|
||||||
|
"wrong robot, robot not at home pose)."
|
||||||
|
),
|
||||||
|
)
|
||||||
p.add_argument(
|
p.add_argument(
|
||||||
"--no_tts",
|
"--no_tts",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
@@ -168,21 +237,13 @@ def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
|
|||||||
def _load_policy_and_preprocessor(
|
def _load_policy_and_preprocessor(
|
||||||
policy_path: str,
|
policy_path: str,
|
||||||
dataset_repo_id: str | None,
|
dataset_repo_id: str | None,
|
||||||
) -> tuple[Any, Any, Any]:
|
) -> tuple[Any, Any, Any, Any]:
|
||||||
"""Load a SmolVLA2 checkpoint (local path or Hub repo id).
|
"""Load a SmolVLA2 checkpoint (local path or Hub repo id).
|
||||||
|
|
||||||
When ``dataset_repo_id`` is provided, the dataset's metadata is used
|
Returns ``(policy, preprocessor, postprocessor, ds_meta)``.
|
||||||
to derive policy features (matching the standard
|
``preprocessor`` / ``postprocessor`` / ``ds_meta`` are ``None``
|
||||||
``make_policy(cfg, ds_meta=...)`` flow used by ``lerobot-train`` and
|
when no dataset is provided (rare — needed for autonomous robot
|
||||||
``lerobot-record``). When it isn't, we fall back to instantiating
|
mode to have action-denormalisation stats).
|
||||||
the policy directly via ``from_pretrained`` — this skips the
|
|
||||||
feature-derivation path that ``make_policy`` insists on, but also
|
|
||||||
means we can't load the saved preprocessor pipeline (which depends
|
|
||||||
on ``input_features`` / ``output_features``). For inference-only
|
|
||||||
dry-runs this is fine; the policy still loads.
|
|
||||||
|
|
||||||
Returns ``(policy, preprocessor, ds_meta)`` where ``preprocessor``
|
|
||||||
and ``ds_meta`` may be ``None`` if no dataset was provided.
|
|
||||||
"""
|
"""
|
||||||
from lerobot.configs import PreTrainedConfig # noqa: PLC0415
|
from lerobot.configs import PreTrainedConfig # noqa: PLC0415
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: PLC0415
|
from lerobot.policies.factory import make_policy, make_pre_post_processors # noqa: PLC0415
|
||||||
@@ -192,34 +253,22 @@ def _load_policy_and_preprocessor(
|
|||||||
|
|
||||||
ds_meta = None
|
ds_meta = None
|
||||||
preprocessor = None
|
preprocessor = None
|
||||||
|
postprocessor = None
|
||||||
if dataset_repo_id is not None:
|
if dataset_repo_id is not None:
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata # noqa: PLC0415
|
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata # noqa: PLC0415
|
||||||
|
|
||||||
ds_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
ds_meta = LeRobotDatasetMetadata(dataset_repo_id)
|
||||||
policy = make_policy(cfg, ds_meta=ds_meta)
|
policy = make_policy(cfg, ds_meta=ds_meta)
|
||||||
# NOTE: we deliberately pass ``pretrained_path=None`` here even
|
# ``pretrained_path=None`` rebuilds fresh — the saved
|
||||||
# though the checkpoint ships a ``policy_preprocessor.json``.
|
# ``policy_preprocessor.json`` doesn't round-trip
|
||||||
# ``RenderMessagesStep`` carries a ``TrainingRecipe`` field that
|
# ``RenderMessagesStep.recipe``. Stats come from the dataset
|
||||||
# isn't faithfully serialized into that JSON, so the saved
|
# the user is feeding through, so normalisation is consistent.
|
||||||
# pipeline can't currently be round-tripped via
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
# ``PolicyProcessorPipeline.from_pretrained`` — it crashes with
|
|
||||||
# ``RenderMessagesStep.__init__() missing 1 required argument:
|
|
||||||
# 'recipe'``. Building fresh from ``cfg`` re-runs
|
|
||||||
# ``make_smolvla2_pre_post_processors``, which loads the recipe
|
|
||||||
# YAML referenced by ``cfg.recipe_path`` and wires it back into
|
|
||||||
# ``RenderMessagesStep`` correctly. Normalization stats come
|
|
||||||
# from ``ds_meta.stats`` (the same dataset the user is feeding
|
|
||||||
# into the runtime), so no quality loss in practice.
|
|
||||||
preprocessor, _ = make_pre_post_processors(
|
|
||||||
cfg,
|
cfg,
|
||||||
pretrained_path=None,
|
pretrained_path=None,
|
||||||
dataset_stats=ds_meta.stats,
|
dataset_stats=ds_meta.stats,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# No dataset: instantiate the policy class directly so we don't
|
|
||||||
# need ds_meta. This bypasses ``make_policy``'s feature-shape
|
|
||||||
# derivation, which is fine for a pretrained checkpoint where
|
|
||||||
# the saved config already carries those shapes.
|
|
||||||
from lerobot.policies.factory import get_policy_class # noqa: PLC0415
|
from lerobot.policies.factory import get_policy_class # noqa: PLC0415
|
||||||
|
|
||||||
policy_cls = get_policy_class(cfg.type)
|
policy_cls = get_policy_class(cfg.type)
|
||||||
@@ -227,7 +276,7 @@ def _load_policy_and_preprocessor(
|
|||||||
policy.to(cfg.device)
|
policy.to(cfg.device)
|
||||||
|
|
||||||
policy.eval()
|
policy.eval()
|
||||||
return policy, preprocessor, ds_meta
|
return policy, preprocessor, postprocessor, ds_meta
|
||||||
|
|
||||||
|
|
||||||
def _build_observation_provider(
|
def _build_observation_provider(
|
||||||
@@ -372,6 +421,230 @@ def _bootstrap_state_from_dataset(
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _build_robot(
|
||||||
|
*,
|
||||||
|
robot_type: str,
|
||||||
|
robot_port: str | None,
|
||||||
|
robot_id: str | None,
|
||||||
|
robot_cameras_json: str | None,
|
||||||
|
):
|
||||||
|
"""Build and connect a robot from CLI args.
|
||||||
|
|
||||||
|
Mirrors how ``lerobot-record`` builds a robot but takes the args
|
||||||
|
flat from argparse instead of through draccus, so the runtime
|
||||||
|
keeps its plain ``--key=value`` CLI surface.
|
||||||
|
"""
|
||||||
|
import json # noqa: PLC0415
|
||||||
|
|
||||||
|
from lerobot.robots import ( # noqa: PLC0415
|
||||||
|
RobotConfig,
|
||||||
|
make_robot_from_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
cls = RobotConfig.get_choice_class(robot_type)
|
||||||
|
kwargs: dict[str, Any] = {}
|
||||||
|
if robot_port:
|
||||||
|
kwargs["port"] = robot_port
|
||||||
|
if robot_id:
|
||||||
|
kwargs["id"] = robot_id
|
||||||
|
if robot_cameras_json:
|
||||||
|
try:
|
||||||
|
kwargs["cameras"] = json.loads(robot_cameras_json)
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
raise ValueError(
|
||||||
|
f"--robot.cameras must be a JSON object, got {robot_cameras_json!r}: {exc}"
|
||||||
|
) from exc
|
||||||
|
cfg = cls(**kwargs)
|
||||||
|
robot = make_robot_from_config(cfg)
|
||||||
|
robot.connect()
|
||||||
|
return robot
|
||||||
|
|
||||||
|
|
||||||
|
def _build_robot_observation_provider(
|
||||||
|
*,
|
||||||
|
robot,
|
||||||
|
preprocessor: Any,
|
||||||
|
device: str,
|
||||||
|
task: str | None,
|
||||||
|
) -> Callable[[], dict | None]:
|
||||||
|
"""Closure that reads from the robot, runs the policy preprocessor.
|
||||||
|
|
||||||
|
Each call: ``robot.get_observation()`` → wrap as a flat sample dict
|
||||||
|
→ drop language columns (the runtime drives messages itself) →
|
||||||
|
preprocessor (rename, batch dim, normalise, device-place) → return
|
||||||
|
the observation batch ready for ``policy.select_action`` and
|
||||||
|
``policy.select_message``.
|
||||||
|
"""
|
||||||
|
import torch # noqa: PLC0415
|
||||||
|
|
||||||
|
def _provider() -> dict | None:
|
||||||
|
try:
|
||||||
|
raw = robot.get_observation()
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("robot.get_observation failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
sample: dict[str, Any] = dict(raw)
|
||||||
|
if task:
|
||||||
|
sample.setdefault("task", task)
|
||||||
|
# The render step expects either both language columns or
|
||||||
|
# neither — runtime supplies messages itself, so make sure
|
||||||
|
# nothing leaks through.
|
||||||
|
for k in ("language_persistent", "language_events"):
|
||||||
|
sample.pop(k, None)
|
||||||
|
|
||||||
|
if preprocessor is not None:
|
||||||
|
try:
|
||||||
|
sample = preprocessor(sample)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("preprocessor failed on robot observation: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
observation = {
|
||||||
|
k: v
|
||||||
|
for k, v in sample.items()
|
||||||
|
if isinstance(k, str) and k.startswith("observation.")
|
||||||
|
}
|
||||||
|
for k, v in list(observation.items()):
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
observation[k] = v.to(device)
|
||||||
|
return observation
|
||||||
|
|
||||||
|
return _provider
|
||||||
|
|
||||||
|
|
||||||
|
def _build_robot_action_executor(
    *,
    robot,
    postprocessor: Any,
    ds_meta: Any,
    max_action_norm: float | None,
) -> Callable[[Any], None]:
    """Closure that postprocesses an action and dispatches it to the robot.

    Mirrors ``lerobot-record``'s ``predict_action`` tail: postprocess
    (denormalise) → ``make_robot_action`` (tensor → ``{joint: value}``
    dict) → ``robot.send_action(...)``.

    Args:
        robot: Object whose ``send_action(action_dict)`` is called.
        postprocessor: Optional callable applied to the raw action first.
        ds_meta: Object whose ``features`` attribute is passed to
            ``make_robot_action`` for tensor actions.
        max_action_norm: Optional kill-switch threshold. When set, any
            action (tensor *or* dict) whose L2 norm exceeds it — or whose
            norm is NaN / cannot be computed — is dropped with a warning.

    Returns:
        ``executor(action) -> None``. The executor never raises: any
        exception is logged and the tick is skipped.
    """
    import math  # noqa: PLC0415

    import torch  # noqa: PLC0415

    from lerobot.policies.utils import make_robot_action  # noqa: PLC0415

    def _l2_norm(action: Any) -> float:
        """L2 norm of a tensor or ``{name: scalar}`` action; NaN if not computable."""
        if isinstance(action, torch.Tensor):
            return float(action.float().norm().item())
        try:
            # ``math.hypot`` is the n-ary Euclidean norm and does not
            # overflow on intermediate squares.
            return math.hypot(*(float(v) for v in action.values()))
        except (TypeError, ValueError, RuntimeError):
            return math.nan

    def _executor(action: Any) -> None:
        try:
            if postprocessor is not None:
                action = postprocessor(action)

            if not isinstance(action, (torch.Tensor, dict)):
                logger.warning("unsupported action type %r — skipping", type(action))
                return

            if max_action_norm is not None:
                norm = _l2_norm(action)
                # ``nan > threshold`` is False, so a NaN action would slip
                # straight past a plain ``>`` check; a safety switch must
                # fail closed instead.
                if math.isnan(norm):
                    logger.warning(
                        "action norm is NaN (max_action_norm=%.3f) — rejecting tick",
                        max_action_norm,
                    )
                    return
                if norm > max_action_norm:
                    logger.warning(
                        "action norm %.3f > max_action_norm=%.3f — "
                        "rejecting tick",
                        norm, max_action_norm,
                    )
                    return

            if isinstance(action, torch.Tensor):
                # Drop a leading singleton (batch) dimension before the
                # tensor → ``{joint: value}`` conversion.
                if action.ndim > 1 and action.shape[0] == 1:
                    action = action.squeeze(0)
                action_dict = make_robot_action(action, ds_meta.features)
            else:
                action_dict = action

            robot.send_action(action_dict)
        except Exception as exc:  # noqa: BLE001
            logger.error("robot.send_action failed: %s", exc, exc_info=True)

    return _executor
|
||||||
|
|
||||||
|
|
||||||
|
def _run_autonomous(
|
||||||
|
runtime: Any,
|
||||||
|
*,
|
||||||
|
robot,
|
||||||
|
auto_start: bool,
|
||||||
|
initial_task: str | None,
|
||||||
|
max_ticks: int | None,
|
||||||
|
) -> int:
|
||||||
|
"""Drive the runtime continuously at ``ctrl_hz`` while accepting
|
||||||
|
stdin events in the foreground.
|
||||||
|
|
||||||
|
Different from ``_run_repl`` (dataset dry-run): the policy needs
|
||||||
|
to keep generating action chunks at ``chunk_hz`` and dispatching
|
||||||
|
them at ``ctrl_hz`` regardless of whether the user is typing, so
|
||||||
|
``runtime.run()`` runs in a background thread and stdin handling
|
||||||
|
happens here in the main thread.
|
||||||
|
"""
|
||||||
|
import threading # noqa: PLC0415
|
||||||
|
import time # noqa: PLC0415
|
||||||
|
|
||||||
|
if not auto_start:
|
||||||
|
try:
|
||||||
|
input(
|
||||||
|
"[smolvla2] Robot connected. Press ENTER to start the autonomous "
|
||||||
|
"control loop, Ctrl+C to abort. "
|
||||||
|
)
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
print("\n[smolvla2] aborted before start", flush=True)
|
||||||
|
return 130
|
||||||
|
|
||||||
|
if initial_task:
|
||||||
|
runtime.set_task(initial_task)
|
||||||
|
|
||||||
|
thread = threading.Thread(
|
||||||
|
target=runtime.run,
|
||||||
|
kwargs={"max_ticks": max_ticks},
|
||||||
|
name="smolvla2-runtime-loop",
|
||||||
|
daemon=True,
|
||||||
|
)
|
||||||
|
thread.start()
|
||||||
|
print(
|
||||||
|
"[smolvla2] autonomous loop running. Type interjections / "
|
||||||
|
"questions on stdin (Ctrl+C to stop).",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while thread.is_alive():
|
||||||
|
try:
|
||||||
|
line = input("> ").strip()
|
||||||
|
except EOFError:
|
||||||
|
break
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
lower = line.lower()
|
||||||
|
if lower in {"stop", "quit", "exit"}:
|
||||||
|
break
|
||||||
|
if not runtime.state.get("task"):
|
||||||
|
runtime.set_task(line[5:].strip() if lower.startswith("task:") else line)
|
||||||
|
continue
|
||||||
|
if lower.endswith("?"):
|
||||||
|
runtime.state["recent_vqa_query"] = line
|
||||||
|
runtime.state.setdefault("events_this_tick", []).append("user_vqa_query")
|
||||||
|
else:
|
||||||
|
runtime.state["recent_interjection"] = line
|
||||||
|
runtime.state.setdefault("events_this_tick", []).append("user_interjection")
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n[smolvla2] interrupt — stopping", flush=True)
|
||||||
|
finally:
|
||||||
|
runtime.stop()
|
||||||
|
# Give the loop a moment to drain.
|
||||||
|
for _ in range(10):
|
||||||
|
if not thread.is_alive():
|
||||||
|
break
|
||||||
|
time.sleep(0.1)
|
||||||
|
try:
|
||||||
|
robot.disconnect()
|
||||||
|
print("[smolvla2] robot disconnected", flush=True)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
print(f"[smolvla2] WARNING: robot.disconnect raised {exc}", flush=True)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
def _build_tools(no_tts: bool, tts_voice: str) -> dict[str, Any]:
|
||||||
"""Instantiate the tools declared on this dataset/policy."""
|
"""Instantiate the tools declared on this dataset/policy."""
|
||||||
if no_tts:
|
if no_tts:
|
||||||
@@ -423,14 +696,69 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
)
|
)
|
||||||
_silence_noisy_loggers()
|
_silence_noisy_loggers()
|
||||||
|
|
||||||
|
autonomous_mode = bool(args.robot_type) and not args.no_robot
|
||||||
|
if autonomous_mode and not args.dataset_repo_id:
|
||||||
|
print(
|
||||||
|
"[smolvla2] ERROR: autonomous robot mode requires --dataset.repo_id "
|
||||||
|
"for action-denormalisation stats and feature shapes. Pass the "
|
||||||
|
"same dataset the policy was trained on.",
|
||||||
|
file=sys.stderr,
|
||||||
|
)
|
||||||
|
return 2
|
||||||
|
|
||||||
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
print(f"[smolvla2] loading policy from {args.policy_path}", flush=True)
|
||||||
policy, preprocessor, _ds_meta = _load_policy_and_preprocessor(
|
policy, preprocessor, postprocessor, ds_meta = _load_policy_and_preprocessor(
|
||||||
args.policy_path, args.dataset_repo_id
|
args.policy_path, args.dataset_repo_id
|
||||||
)
|
)
|
||||||
|
|
||||||
observation_provider: Callable[[], dict | None] | None = None
|
# Bootstrap canonical task / plan / memory / subtask from the
|
||||||
|
# dataset whenever one is provided — both REPL dry-run and
|
||||||
|
# autonomous robot mode benefit, since the model is memorised on
|
||||||
|
# the exact training prompts and matching wording is what gets
|
||||||
|
# recall to fire.
|
||||||
bootstrap_state: dict[str, str] = {}
|
bootstrap_state: dict[str, str] = {}
|
||||||
if args.dataset_repo_id is not None:
|
if args.dataset_repo_id is not None:
|
||||||
|
bootstrap_state = _bootstrap_state_from_dataset(
|
||||||
|
dataset_repo_id=args.dataset_repo_id,
|
||||||
|
episode=args.dataset_episode,
|
||||||
|
start_frame=args.dataset_start_frame,
|
||||||
|
)
|
||||||
|
if bootstrap_state.get("task") and not args.task:
|
||||||
|
args.task = bootstrap_state["task"]
|
||||||
|
print(
|
||||||
|
f"[smolvla2] using canonical task from dataset: {args.task!r}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
observation_provider: Callable[[], dict | None] | None = None
|
||||||
|
robot_executor: Callable[[Any], None] | None = None
|
||||||
|
robot = None
|
||||||
|
|
||||||
|
if autonomous_mode:
|
||||||
|
print(
|
||||||
|
f"[smolvla2] connecting to robot.type={args.robot_type} "
|
||||||
|
f"port={args.robot_port}",
|
||||||
|
flush=True,
|
||||||
|
)
|
||||||
|
robot = _build_robot(
|
||||||
|
robot_type=args.robot_type,
|
||||||
|
robot_port=args.robot_port,
|
||||||
|
robot_id=args.robot_id,
|
||||||
|
robot_cameras_json=args.robot_cameras,
|
||||||
|
)
|
||||||
|
observation_provider = _build_robot_observation_provider(
|
||||||
|
robot=robot,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
device=str(getattr(policy.config, "device", "cpu")),
|
||||||
|
task=args.task,
|
||||||
|
)
|
||||||
|
robot_executor = _build_robot_action_executor(
|
||||||
|
robot=robot,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
ds_meta=ds_meta,
|
||||||
|
max_action_norm=args.max_action_norm,
|
||||||
|
)
|
||||||
|
elif args.dataset_repo_id is not None:
|
||||||
print(
|
print(
|
||||||
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
f"[smolvla2] streaming observations from {args.dataset_repo_id} "
|
||||||
f"episode={args.dataset_episode} "
|
f"episode={args.dataset_episode} "
|
||||||
@@ -445,38 +773,11 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
device=str(getattr(policy.config, "device", "cpu")),
|
device=str(getattr(policy.config, "device", "cpu")),
|
||||||
)
|
)
|
||||||
# Pull the dataset's canonical task + the persistent atoms in
|
|
||||||
# force at the chosen start frame. The model is heavily
|
|
||||||
# memorised on the *exact* training prompts (task wording,
|
|
||||||
# current plan, current memory) — feeding ad-hoc user
|
|
||||||
# alternatives gives it nothing to recall against, so it
|
|
||||||
# collapses to its dominant training mode (VQA JSON). Reading
|
|
||||||
# the canonical state straight from the dataset gives the
|
|
||||||
# runtime a starting point that lines up with training.
|
|
||||||
bootstrap_state = _bootstrap_state_from_dataset(
|
|
||||||
dataset_repo_id=args.dataset_repo_id,
|
|
||||||
episode=args.dataset_episode,
|
|
||||||
start_frame=args.dataset_start_frame,
|
|
||||||
)
|
|
||||||
if bootstrap_state.get("task") and not args.task:
|
|
||||||
args.task = bootstrap_state["task"]
|
|
||||||
print(
|
|
||||||
f"[smolvla2] using canonical task from dataset: {args.task!r}",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = _build_tools(args.no_tts, args.tts_voice)
|
tools = _build_tools(args.no_tts, args.tts_voice)
|
||||||
if tools:
|
if tools:
|
||||||
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
|
print(f"[smolvla2] tools loaded: {list(tools)}", flush=True)
|
||||||
|
|
||||||
robot_executor = None
|
|
||||||
if not args.no_robot:
|
|
||||||
print(
|
|
||||||
"[smolvla2] WARNING: real-robot integration is a follow-up. "
|
|
||||||
"Running in dry-run mode for now (no actions executed).",
|
|
||||||
flush=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415
|
from lerobot.policies.smolvla2.inference import SmolVLA2Runtime # noqa: PLC0415
|
||||||
|
|
||||||
runtime = SmolVLA2Runtime(
|
runtime = SmolVLA2Runtime(
|
||||||
@@ -485,10 +786,9 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
observation_provider=observation_provider,
|
observation_provider=observation_provider,
|
||||||
robot_executor=robot_executor,
|
robot_executor=robot_executor,
|
||||||
# No background event collector — the REPL drives ticks
|
# No background event collector — the REPL drives ticks
|
||||||
# synchronously after each user input. The runtime's own
|
# synchronously after each user input (REPL mode). Autonomous
|
||||||
# ``run()`` loop is bypassed here in favour of ``step_once()``
|
# mode runs ``runtime.run()`` in a thread; stdin events are
|
||||||
# so the input prompt and the live state panel co-exist
|
# injected from the foreground.
|
||||||
# cleanly.
|
|
||||||
event_collector=None,
|
event_collector=None,
|
||||||
chunk_hz=args.chunk_hz,
|
chunk_hz=args.chunk_hz,
|
||||||
ctrl_hz=args.ctrl_hz,
|
ctrl_hz=args.ctrl_hz,
|
||||||
@@ -496,10 +796,10 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
)
|
)
|
||||||
if args.task:
|
if args.task:
|
||||||
runtime.set_task(args.task)
|
runtime.set_task(args.task)
|
||||||
# Bootstrap plan/memory from the dataset so the first prompt the
|
# Seed plan/memory/subtask so the first prompt the runtime builds
|
||||||
# runtime builds matches what training rendered (task + active
|
# mirrors what training rendered (task + active plan + active
|
||||||
# plan + active memory). Without this the runtime starts with
|
# memory + optional current subtask). Without this the runtime
|
||||||
# plan/memory empty, which only matched the very-early frames in
|
# starts empty, which only matched the very-early frames during
|
||||||
# training and is an out-of-distribution prompt for the rest.
|
# training and is an out-of-distribution prompt for the rest.
|
||||||
if bootstrap_state.get("plan"):
|
if bootstrap_state.get("plan"):
|
||||||
runtime.state["current_plan"] = bootstrap_state["plan"]
|
runtime.state["current_plan"] = bootstrap_state["plan"]
|
||||||
@@ -508,6 +808,14 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
if bootstrap_state.get("subtask"):
|
if bootstrap_state.get("subtask"):
|
||||||
runtime.state["current_subtask"] = bootstrap_state["subtask"]
|
runtime.state["current_subtask"] = bootstrap_state["subtask"]
|
||||||
|
|
||||||
|
if autonomous_mode:
|
||||||
|
return _run_autonomous(
|
||||||
|
runtime,
|
||||||
|
robot=robot,
|
||||||
|
auto_start=args.auto_start,
|
||||||
|
initial_task=args.task,
|
||||||
|
max_ticks=args.max_ticks,
|
||||||
|
)
|
||||||
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
return _run_repl(runtime, initial_task=args.task, max_ticks=args.max_ticks)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user