diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index c92f26a84..f59120936 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -567,21 +567,27 @@ def _build_robot_observation_provider( preprocessor: Any, device: str, task: str | None, + ds_features: dict[str, Any] | None, ) -> Callable[[], dict | None]: """Closure that reads from the robot, runs the policy preprocessor. - Each call: ``robot.get_observation()`` (raw numpy dict) → - ``prepare_observation_for_inference`` (tensor / batch dim / device) → - wrap in an ``EnvTransition`` (the preprocessor pipeline is - transition-shaped, keyed by ``TransitionKey``) → preprocessor - (rename, render-messages no-op when no language columns, chat - tokenizer no-op when no messages, normalise) → unwrap and return - the flat observation batch ``policy.select_action`` / - ``policy.select_message`` consume. + Each call: ``robot.get_observation()`` (raw per-joint + per-camera + dict, possibly with scalar floats) → ``build_inference_frame`` + (extract the keys the dataset declared, reshape per-joint floats + into a single ``observation.state`` vector, prefix camera keys + with ``observation.images.``, convert to tensors with batch dim + on device) → wrap in an ``EnvTransition`` (the preprocessor + pipeline is transition-shaped, keyed by ``TransitionKey``) → + preprocessor (rename, normalise) → unwrap and return the flat + observation batch ``policy.select_action`` / ``policy.select_message`` + consume. """ import torch # noqa: PLC0415 - from lerobot.policies.utils import prepare_observation_for_inference # noqa: PLC0415 + from lerobot.policies.utils import ( # noqa: PLC0415 + build_inference_frame, + prepare_observation_for_inference, + ) from lerobot.types import TransitionKey # noqa: PLC0415 torch_device = torch.device(device) if isinstance(device, str) else device @@ -602,11 +608,24 @@ def _build_robot_observation_provider( raw.pop(k, None) try: - obs_tensors = prepare_observation_for_inference( - raw, torch_device, task=task, robot_type=robot_type - ) + if ds_features: + # Use the dataset's feature schema to pick the right + # raw keys and fold per-joint scalars into a single + # ``observation.state`` tensor. Then tensor-ise + + # device-place + add batch dim. + obs_tensors = build_inference_frame( + raw, torch_device, ds_features=ds_features, + task=task, robot_type=robot_type, + ) + else: + # No dataset features available — fall back to the + # generic numpy-only path; only works when the robot + # already returns dataset-shaped keys. + obs_tensors = prepare_observation_for_inference( + raw, torch_device, task=task, robot_type=robot_type, + ) except Exception as exc: # noqa: BLE001 - logger.warning("prepare_observation_for_inference failed: %s", exc) + logger.warning("observation prep failed: %s", exc) return None if preprocessor is not None: @@ -869,6 +888,7 @@ def main(argv: list[str] | None = None) -> int: preprocessor=preprocessor, device=str(getattr(policy.config, "device", "cpu")), task=args.task, + ds_features=ds_meta.features if ds_meta is not None else None, ) robot_executor = _build_robot_action_executor( robot=robot,