fix(smolvla2): wrap robot obs in EnvTransition before preprocessor

The policy preprocessor pipeline is transition-shaped — its steps
read ``TransitionKey.OBSERVATION`` off an ``EnvTransition`` dict, not
a flat ``RobotObservation`` dict. Passing the raw observation through
made every step bail with
``ObservationProcessorStep requires an observation in the transition``,
which the runtime swallowed at warning level. ``select_message`` then
got called with no ``observation.images.*`` features and crashed
with ``All image features are missing from the batch``.

Mirror ``lerobot-record``'s preamble:
  1. ``prepare_observation_for_inference`` → numpy → torch, ``CHW``
     image layout, ``[0,1]`` scaling, add batch dim, move to device.
  2. Wrap into an ``EnvTransition`` (``{TransitionKey.OBSERVATION.value:
     ...}`` plus ``COMPLEMENTARY_DATA: {}`` and ``None``s for the rest)
     so transition-aware steps see the keys they expect.
  3. Run preprocessor.
  4. Unwrap the transition's ``OBSERVATION`` slot to get the final
     flat dict the policy's ``select_action`` / ``select_message``
     consume.

Image features now reach the policy; the autonomous loop produces
real actions instead of swallowing warnings every tick.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-12 14:44:24 +02:00
parent 41095e3cc3
commit afe40a016b
+40 -15
View File
@@ -570,14 +570,25 @@ def _build_robot_observation_provider(
) -> Callable[[], dict | None]:
"""Closure that reads from the robot, runs the policy preprocessor.
Each call: ``robot.get_observation()`` → wrap as a flat sample dict
→ drop language columns (the runtime drives messages itself) →
preprocessor (rename, batch dim, normalise, device-place) → return
the observation batch ready for ``policy.select_action`` and
``policy.select_message``.
Each call: ``robot.get_observation()`` (raw numpy dict) →
``prepare_observation_for_inference`` (tensor / batch dim / device) →
wrap in an ``EnvTransition`` (the preprocessor pipeline is
transition-shaped, keyed by ``TransitionKey``) → preprocessor
(rename, render-messages no-op when no language columns, chat
tokenizer no-op when no messages, normalise) → unwrap and return
the flat observation batch ``policy.select_action`` /
``policy.select_message`` consume.
"""
import torch # noqa: PLC0415
from lerobot.policies.utils import prepare_observation_for_inference # noqa: PLC0415
from lerobot.types import TransitionKey # noqa: PLC0415
torch_device = torch.device(device) if isinstance(device, str) else device
robot_type = getattr(robot, "robot_type", None) or getattr(
getattr(robot, "config", None), "type", None
)
def _provider() -> dict | None:
try:
raw = robot.get_observation()
@@ -585,30 +596,44 @@ def _build_robot_observation_provider(
logger.warning("robot.get_observation failed: %s", exc)
return None
sample: dict[str, Any] = dict(raw)
if task:
sample.setdefault("task", task)
# The render step expects either both language columns or
# neither — runtime supplies messages itself, so make sure
# nothing leaks through.
# Strip language-column leakage just in case (the runtime
# supplies messages itself).
for k in ("language_persistent", "language_events"):
sample.pop(k, None)
raw.pop(k, None)
try:
obs_tensors = prepare_observation_for_inference(
raw, torch_device, task=task, robot_type=robot_type
)
except Exception as exc: # noqa: BLE001
logger.warning("prepare_observation_for_inference failed: %s", exc)
return None
if preprocessor is not None:
transition: dict[str, Any] = {
TransitionKey.OBSERVATION.value: obs_tensors,
TransitionKey.ACTION.value: None,
TransitionKey.REWARD.value: None,
TransitionKey.DONE.value: None,
TransitionKey.TRUNCATED.value: None,
TransitionKey.INFO.value: None,
TransitionKey.COMPLEMENTARY_DATA.value: {},
}
try:
sample = preprocessor(sample)
transition = preprocessor(transition)
except Exception as exc: # noqa: BLE001
logger.warning("preprocessor failed on robot observation: %s", exc)
return None
obs_tensors = transition.get(TransitionKey.OBSERVATION.value) or {}
observation = {
k: v
for k, v in sample.items()
for k, v in obs_tensors.items()
if isinstance(k, str) and k.startswith("observation.")
}
for k, v in list(observation.items()):
if isinstance(v, torch.Tensor):
observation[k] = v.to(device)
observation[k] = v.to(torch_device)
return observation
return _provider