fix(smolvla2): use build_inference_frame for raw robot observations

``robot.get_observation()`` on omx_follower (and most lerobot robots)
returns:

  * per-joint scalar floats with ``.pos`` suffix
    (``shoulder_pan.pos: 0.123``, ``shoulder_lift.pos: 0.456``, ...)
  * per-camera ndarrays keyed by the camera config name (``wrist:
    ndarray(H,W,3)``)

But the trained policy expects:

  * single ``observation.state: tensor[N_joints]`` vector
  * image keys prefixed: ``observation.images.<cam_key>:
    tensor[1, 3, H, W]``

``prepare_observation_for_inference`` only handles the tensor /
batch-dim / device step — it crashes on scalar floats with
``expected np.ndarray (got float)``. The right helper is
``build_inference_frame`` which uses the dataset's feature schema
(``ds_meta.features``) to:

  1. extract the right raw keys per dataset feature,
  2. fold ``shoulder_pan.pos`` / ``shoulder_lift.pos`` / ...
     into a single ``observation.state`` ndarray,
  3. prefix camera keys with ``observation.images.``,
  4. delegate to ``prepare_observation_for_inference`` for the
     tensor / batch / device step.

Pass ``ds_meta.features`` into the observation provider and switch
to ``build_inference_frame`` when available; fall back to the bare
``prepare_observation_for_inference`` only when no dataset is
provided (rare — autonomous mode already requires it).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 14:47:59 +02:00
parent afe40a016b
commit 992d13d4e9
+33 -13
View File
@@ -567,21 +567,27 @@ def _build_robot_observation_provider(
preprocessor: Any,
device: str,
task: str | None,
ds_features: dict[str, Any] | None,
) -> Callable[[], dict | None]:
"""Closure that reads from the robot, runs the policy preprocessor.
Each call: ``robot.get_observation()`` (raw numpy dict) →
``prepare_observation_for_inference`` (tensor / batch dim / device) →
wrap in an ``EnvTransition`` (the preprocessor pipeline is
transition-shaped, keyed by ``TransitionKey``) → preprocessor
(rename, render-messages no-op when no language columns, chat
tokenizer no-op when no messages, normalise) → unwrap and return
the flat observation batch ``policy.select_action`` /
``policy.select_message`` consume.
Each call: ``robot.get_observation()`` (raw per-joint + per-camera
dict, possibly with scalar floats) → ``build_inference_frame``
(extract the keys the dataset declared, reshape per-joint floats
into a single ``observation.state`` vector, prefix camera keys
with ``observation.images.``, convert to tensors with batch dim
on device) → wrap in an ``EnvTransition`` (the preprocessor
pipeline is transition-shaped, keyed by ``TransitionKey``) →
preprocessor (rename, normalise) → unwrap and return the flat
observation batch ``policy.select_action`` / ``policy.select_message``
consume.
"""
import torch # noqa: PLC0415
from lerobot.policies.utils import prepare_observation_for_inference # noqa: PLC0415
from lerobot.policies.utils import ( # noqa: PLC0415
build_inference_frame,
prepare_observation_for_inference,
)
from lerobot.types import TransitionKey # noqa: PLC0415
torch_device = torch.device(device) if isinstance(device, str) else device
@@ -602,11 +608,24 @@ def _build_robot_observation_provider(
raw.pop(k, None)
try:
obs_tensors = prepare_observation_for_inference(
raw, torch_device, task=task, robot_type=robot_type
)
if ds_features:
# Use the dataset's feature schema to pick the right
# raw keys and fold per-joint scalars into a single
# ``observation.state`` tensor. Then tensor-ise +
# device-place + add batch dim.
obs_tensors = build_inference_frame(
raw, torch_device, ds_features=ds_features,
task=task, robot_type=robot_type,
)
else:
# No dataset features available — fall back to the
# generic numpy-only path; only works when the robot
# already returns dataset-shaped keys.
obs_tensors = prepare_observation_for_inference(
raw, torch_device, task=task, robot_type=robot_type,
)
except Exception as exc: # noqa: BLE001
logger.warning("prepare_observation_for_inference failed: %s", exc)
logger.warning("observation prep failed: %s", exc)
return None
if preprocessor is not None:
@@ -869,6 +888,7 @@ def main(argv: list[str] | None = None) -> int:
preprocessor=preprocessor,
device=str(getattr(policy.config, "device", "cpu")),
task=args.task,
ds_features=ds_meta.features if ds_meta is not None else None,
)
robot_executor = _build_robot_action_executor(
robot=robot,