mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
fix(smolvla2): wrap robot obs in EnvTransition before preprocessor
The policy preprocessor pipeline is transition-shaped — its steps
read ``TransitionKey.OBSERVATION`` off an ``EnvTransition`` dict, not
a flat ``RobotObservation`` dict. Passing the raw observation through
made every step bail with
``ObservationProcessorStep requires an observation in the transition``,
which the runtime swallowed at warning level. ``select_message`` then
got called with no ``observation.images.*`` features and crashed
with ``All image features are missing from the batch``.
Mirror ``lerobot-record``'s preamble:
1. ``prepare_observation_for_inference``: convert numpy arrays to torch
tensors, ``CHW`` image layout, ``[0,1]`` scaling, add batch dim, move to device.
2. Wrap into an ``EnvTransition`` (``{TransitionKey.OBSERVATION.value:
...}`` plus ``COMPLEMENTARY_DATA: {}`` and ``None``s for the rest)
so transition-aware steps see the keys they expect.
3. Run preprocessor.
4. Unwrap the transition's ``OBSERVATION`` slot to get the final
flat dict the policy's ``select_action`` / ``select_message``
consume.
Image features now reach the policy; the autonomous loop produces
real actions instead of swallowing warnings every tick.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -570,14 +570,25 @@ def _build_robot_observation_provider(
|
|||||||
) -> Callable[[], dict | None]:
|
) -> Callable[[], dict | None]:
|
||||||
"""Closure that reads from the robot, runs the policy preprocessor.
|
"""Closure that reads from the robot, runs the policy preprocessor.
|
||||||
|
|
||||||
Each call: ``robot.get_observation()`` → wrap as a flat sample dict
|
Each call: ``robot.get_observation()`` (raw numpy dict) →
|
||||||
→ drop language columns (the runtime drives messages itself) →
|
``prepare_observation_for_inference`` (tensor / batch dim / device) →
|
||||||
preprocessor (rename, batch dim, normalise, device-place) → return
|
wrap in an ``EnvTransition`` (the preprocessor pipeline is
|
||||||
the observation batch ready for ``policy.select_action`` and
|
transition-shaped, keyed by ``TransitionKey``) → preprocessor
|
||||||
``policy.select_message``.
|
(rename, render-messages no-op when no language columns, chat
|
||||||
|
tokenizer no-op when no messages, normalise) → unwrap and return
|
||||||
|
the flat observation batch ``policy.select_action`` /
|
||||||
|
``policy.select_message`` consume.
|
||||||
"""
|
"""
|
||||||
import torch # noqa: PLC0415
|
import torch # noqa: PLC0415
|
||||||
|
|
||||||
|
from lerobot.policies.utils import prepare_observation_for_inference # noqa: PLC0415
|
||||||
|
from lerobot.types import TransitionKey # noqa: PLC0415
|
||||||
|
|
||||||
|
torch_device = torch.device(device) if isinstance(device, str) else device
|
||||||
|
robot_type = getattr(robot, "robot_type", None) or getattr(
|
||||||
|
getattr(robot, "config", None), "type", None
|
||||||
|
)
|
||||||
|
|
||||||
def _provider() -> dict | None:
|
def _provider() -> dict | None:
|
||||||
try:
|
try:
|
||||||
raw = robot.get_observation()
|
raw = robot.get_observation()
|
||||||
@@ -585,30 +596,44 @@ def _build_robot_observation_provider(
|
|||||||
logger.warning("robot.get_observation failed: %s", exc)
|
logger.warning("robot.get_observation failed: %s", exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
sample: dict[str, Any] = dict(raw)
|
# Strip language-column leakage just in case (the runtime
|
||||||
if task:
|
# supplies messages itself).
|
||||||
sample.setdefault("task", task)
|
|
||||||
# The render step expects either both language columns or
|
|
||||||
# neither — runtime supplies messages itself, so make sure
|
|
||||||
# nothing leaks through.
|
|
||||||
for k in ("language_persistent", "language_events"):
|
for k in ("language_persistent", "language_events"):
|
||||||
sample.pop(k, None)
|
raw.pop(k, None)
|
||||||
|
|
||||||
|
try:
|
||||||
|
obs_tensors = prepare_observation_for_inference(
|
||||||
|
raw, torch_device, task=task, robot_type=robot_type
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
logger.warning("prepare_observation_for_inference failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
if preprocessor is not None:
|
if preprocessor is not None:
|
||||||
|
transition: dict[str, Any] = {
|
||||||
|
TransitionKey.OBSERVATION.value: obs_tensors,
|
||||||
|
TransitionKey.ACTION.value: None,
|
||||||
|
TransitionKey.REWARD.value: None,
|
||||||
|
TransitionKey.DONE.value: None,
|
||||||
|
TransitionKey.TRUNCATED.value: None,
|
||||||
|
TransitionKey.INFO.value: None,
|
||||||
|
TransitionKey.COMPLEMENTARY_DATA.value: {},
|
||||||
|
}
|
||||||
try:
|
try:
|
||||||
sample = preprocessor(sample)
|
transition = preprocessor(transition)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("preprocessor failed on robot observation: %s", exc)
|
logger.warning("preprocessor failed on robot observation: %s", exc)
|
||||||
return None
|
return None
|
||||||
|
obs_tensors = transition.get(TransitionKey.OBSERVATION.value) or {}
|
||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
k: v
|
k: v
|
||||||
for k, v in sample.items()
|
for k, v in obs_tensors.items()
|
||||||
if isinstance(k, str) and k.startswith("observation.")
|
if isinstance(k, str) and k.startswith("observation.")
|
||||||
}
|
}
|
||||||
for k, v in list(observation.items()):
|
for k, v in list(observation.items()):
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
observation[k] = v.to(device)
|
observation[k] = v.to(torch_device)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
return _provider
|
return _provider
|
||||||
|
|||||||
Reference in New Issue
Block a user