diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index 00c63473c..c92f26a84 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -570,14 +570,25 @@ def _build_robot_observation_provider( ) -> Callable[[], dict | None]: """Closure that reads from the robot, runs the policy preprocessor. - Each call: ``robot.get_observation()`` → wrap as a flat sample dict - → drop language columns (the runtime drives messages itself) → - preprocessor (rename, batch dim, normalise, device-place) → return - the observation batch ready for ``policy.select_action`` and - ``policy.select_message``. + Each call: ``robot.get_observation()`` (raw numpy dict) → + ``prepare_observation_for_inference`` (tensor / batch dim / device) → + wrap in an ``EnvTransition`` (the preprocessor pipeline is + transition-shaped, keyed by ``TransitionKey``) → preprocessor + (rename, render-messages no-op when no language columns, chat + tokenizer no-op when no messages, normalise) → unwrap and return + the flat observation batch ``policy.select_action`` / + ``policy.select_message`` consume. """ import torch # noqa: PLC0415 + from lerobot.policies.utils import prepare_observation_for_inference # noqa: PLC0415 + from lerobot.types import TransitionKey # noqa: PLC0415 + + torch_device = torch.device(device) if isinstance(device, str) else device + robot_type = getattr(robot, "robot_type", None) or getattr( + getattr(robot, "config", None), "type", None + ) + def _provider() -> dict | None: try: raw = robot.get_observation() @@ -585,30 +596,44 @@ def _build_robot_observation_provider( logger.warning("robot.get_observation failed: %s", exc) return None - sample: dict[str, Any] = dict(raw) - if task: - sample.setdefault("task", task) - # The render step expects either both language columns or - # neither — runtime supplies messages itself, so make sure - # nothing leaks through. + # Strip language-column leakage just in case (the runtime + # supplies messages itself). for k in ("language_persistent", "language_events"): - sample.pop(k, None) + raw.pop(k, None) + + try: + obs_tensors = prepare_observation_for_inference( + raw, torch_device, task=task, robot_type=robot_type + ) + except Exception as exc: # noqa: BLE001 + logger.warning("prepare_observation_for_inference failed: %s", exc) + return None if preprocessor is not None: + transition: dict[str, Any] = { + TransitionKey.OBSERVATION.value: obs_tensors, + TransitionKey.ACTION.value: None, + TransitionKey.REWARD.value: None, + TransitionKey.DONE.value: None, + TransitionKey.TRUNCATED.value: None, + TransitionKey.INFO.value: None, + TransitionKey.COMPLEMENTARY_DATA.value: {}, + } try: - sample = preprocessor(sample) + transition = preprocessor(transition) except Exception as exc: # noqa: BLE001 logger.warning("preprocessor failed on robot observation: %s", exc) return None + obs_tensors = transition.get(TransitionKey.OBSERVATION.value) or {} observation = { k: v - for k, v in sample.items() + for k, v in obs_tensors.items() if isinstance(k, str) and k.startswith("observation.") } for k, v in list(observation.items()): if isinstance(v, torch.Tensor): - observation[k] = v.to(device) + observation[k] = v.to(torch_device) return observation return _provider