mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
fix(smolvla2): wrap robot obs in EnvTransition before preprocessor
The policy preprocessor pipeline is transition-shaped — its steps
read ``TransitionKey.OBSERVATION`` off an ``EnvTransition`` dict, not
a flat ``RobotObservation`` dict. Passing the raw observation through
made every step bail with
``ObservationProcessorStep requires an observation in the transition``,
which the runtime swallowed at warning level. ``select_message`` then
got called with no ``observation.images.*`` features and crashed
with ``All image features are missing from the batch``.
Mirror ``lerobot-record``'s preamble:
1. ``prepare_observation_for_inference`` → numpy → torch, ``CHW``
image layout, ``[0,1]`` scaling, add batch dim, move to device.
2. Wrap into an ``EnvTransition`` (``{TransitionKey.OBSERVATION.value:
...}`` plus ``COMPLEMENTARY_DATA: {}`` and ``None``s for the rest)
so transition-aware steps see the keys they expect.
3. Run preprocessor.
4. Unwrap the transition's ``OBSERVATION`` slot to get the final
flat dict the policy's ``select_action`` / ``select_message``
consume.
Image features now reach the policy; the autonomous loop produces
real actions instead of swallowing warnings every tick.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -570,14 +570,25 @@ def _build_robot_observation_provider(
|
||||
) -> Callable[[], dict | None]:
|
||||
"""Closure that reads from the robot, runs the policy preprocessor.
|
||||
|
||||
Each call: ``robot.get_observation()`` → wrap as a flat sample dict
|
||||
→ drop language columns (the runtime drives messages itself) →
|
||||
preprocessor (rename, batch dim, normalise, device-place) → return
|
||||
the observation batch ready for ``policy.select_action`` and
|
||||
``policy.select_message``.
|
||||
Each call: ``robot.get_observation()`` (raw numpy dict) →
|
||||
``prepare_observation_for_inference`` (tensor / batch dim / device) →
|
||||
wrap in an ``EnvTransition`` (the preprocessor pipeline is
|
||||
transition-shaped, keyed by ``TransitionKey``) → preprocessor
|
||||
(rename, render-messages no-op when no language columns, chat
|
||||
tokenizer no-op when no messages, normalise) → unwrap and return
|
||||
the flat observation batch ``policy.select_action`` /
|
||||
``policy.select_message`` consume.
|
||||
"""
|
||||
import torch # noqa: PLC0415
|
||||
|
||||
from lerobot.policies.utils import prepare_observation_for_inference # noqa: PLC0415
|
||||
from lerobot.types import TransitionKey # noqa: PLC0415
|
||||
|
||||
torch_device = torch.device(device) if isinstance(device, str) else device
|
||||
robot_type = getattr(robot, "robot_type", None) or getattr(
|
||||
getattr(robot, "config", None), "type", None
|
||||
)
|
||||
|
||||
def _provider() -> dict | None:
|
||||
try:
|
||||
raw = robot.get_observation()
|
||||
@@ -585,30 +596,44 @@ def _build_robot_observation_provider(
|
||||
logger.warning("robot.get_observation failed: %s", exc)
|
||||
return None
|
||||
|
||||
sample: dict[str, Any] = dict(raw)
|
||||
if task:
|
||||
sample.setdefault("task", task)
|
||||
# The render step expects either both language columns or
|
||||
# neither — runtime supplies messages itself, so make sure
|
||||
# nothing leaks through.
|
||||
# Strip language-column leakage just in case (the runtime
|
||||
# supplies messages itself).
|
||||
for k in ("language_persistent", "language_events"):
|
||||
sample.pop(k, None)
|
||||
raw.pop(k, None)
|
||||
|
||||
try:
|
||||
obs_tensors = prepare_observation_for_inference(
|
||||
raw, torch_device, task=task, robot_type=robot_type
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("prepare_observation_for_inference failed: %s", exc)
|
||||
return None
|
||||
|
||||
if preprocessor is not None:
|
||||
transition: dict[str, Any] = {
|
||||
TransitionKey.OBSERVATION.value: obs_tensors,
|
||||
TransitionKey.ACTION.value: None,
|
||||
TransitionKey.REWARD.value: None,
|
||||
TransitionKey.DONE.value: None,
|
||||
TransitionKey.TRUNCATED.value: None,
|
||||
TransitionKey.INFO.value: None,
|
||||
TransitionKey.COMPLEMENTARY_DATA.value: {},
|
||||
}
|
||||
try:
|
||||
sample = preprocessor(sample)
|
||||
transition = preprocessor(transition)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("preprocessor failed on robot observation: %s", exc)
|
||||
return None
|
||||
obs_tensors = transition.get(TransitionKey.OBSERVATION.value) or {}
|
||||
|
||||
observation = {
|
||||
k: v
|
||||
for k, v in sample.items()
|
||||
for k, v in obs_tensors.items()
|
||||
if isinstance(k, str) and k.startswith("observation.")
|
||||
}
|
||||
for k, v in list(observation.items()):
|
||||
if isinstance(v, torch.Tensor):
|
||||
observation[k] = v.to(device)
|
||||
observation[k] = v.to(torch_device)
|
||||
return observation
|
||||
|
||||
return _provider
|
||||
|
||||
Reference in New Issue
Block a user