mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
fix(smolvla2): use build_inference_frame for raw robot observations
``robot.get_observation()`` on omx_follower (and most lerobot robots)
returns:
* per-joint scalar floats with ``.pos`` suffix
(``shoulder_pan.pos: 0.123``, ``shoulder_lift.pos: 0.456``, ...)
* per-camera ndarrays keyed by the camera config name
(``wrist: ndarray(H, W, 3)``)
But the trained policy expects:
* a single ``observation.state: tensor[N_joints]`` vector
* image keys prefixed with ``observation.images.``
(``observation.images.<cam_key>: tensor[1, 3, H, W]``)
``prepare_observation_for_inference`` only handles the final
tensor-conversion / batch-dim / device-placement step — it crashes on
scalar floats with ``expected np.ndarray (got float)``. The right helper is
``build_inference_frame`` which uses the dataset's feature schema
(``ds_meta.features``) to:
1. extract the right raw keys per dataset feature,
2. fold ``shoulder_pan.pos`` / ``shoulder_lift.pos`` / ...
into a single ``observation.state`` ndarray,
3. prefix camera keys with ``observation.images.``,
4. delegate to ``prepare_observation_for_inference`` for the
tensor / batch / device step.
Pass ``ds_meta.features`` into the observation provider and switch
to ``build_inference_frame`` when those features are available; fall
back to the bare ``prepare_observation_for_inference`` only when no
dataset is provided (rare — autonomous mode already requires a dataset).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -567,21 +567,27 @@ def _build_robot_observation_provider(
|
|||||||
preprocessor: Any,
|
preprocessor: Any,
|
||||||
device: str,
|
device: str,
|
||||||
task: str | None,
|
task: str | None,
|
||||||
|
ds_features: dict[str, Any] | None,
|
||||||
) -> Callable[[], dict | None]:
|
) -> Callable[[], dict | None]:
|
||||||
"""Closure that reads from the robot, runs the policy preprocessor.
|
"""Closure that reads from the robot, runs the policy preprocessor.
|
||||||
|
|
||||||
Each call: ``robot.get_observation()`` (raw numpy dict) →
|
Each call: ``robot.get_observation()`` (raw per-joint + per-camera
|
||||||
``prepare_observation_for_inference`` (tensor / batch dim / device) →
|
dict, possibly with scalar floats) → ``build_inference_frame``
|
||||||
wrap in an ``EnvTransition`` (the preprocessor pipeline is
|
(extract the keys the dataset declared, reshape per-joint floats
|
||||||
transition-shaped, keyed by ``TransitionKey``) → preprocessor
|
into a single ``observation.state`` vector, prefix camera keys
|
||||||
(rename, render-messages no-op when no language columns, chat
|
with ``observation.images.``, convert to tensors with batch dim
|
||||||
tokenizer no-op when no messages, normalise) → unwrap and return
|
on device) → wrap in an ``EnvTransition`` (the preprocessor
|
||||||
the flat observation batch ``policy.select_action`` /
|
pipeline is transition-shaped, keyed by ``TransitionKey``) →
|
||||||
``policy.select_message`` consume.
|
preprocessor (rename, normalise) → unwrap and return the flat
|
||||||
|
observation batch ``policy.select_action`` / ``policy.select_message``
|
||||||
|
consume.
|
||||||
"""
|
"""
|
||||||
import torch # noqa: PLC0415
|
import torch # noqa: PLC0415
|
||||||
|
|
||||||
from lerobot.policies.utils import prepare_observation_for_inference # noqa: PLC0415
|
from lerobot.policies.utils import ( # noqa: PLC0415
|
||||||
|
build_inference_frame,
|
||||||
|
prepare_observation_for_inference,
|
||||||
|
)
|
||||||
from lerobot.types import TransitionKey # noqa: PLC0415
|
from lerobot.types import TransitionKey # noqa: PLC0415
|
||||||
|
|
||||||
torch_device = torch.device(device) if isinstance(device, str) else device
|
torch_device = torch.device(device) if isinstance(device, str) else device
|
||||||
@@ -602,11 +608,24 @@ def _build_robot_observation_provider(
|
|||||||
raw.pop(k, None)
|
raw.pop(k, None)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
if ds_features:
|
||||||
|
# Use the dataset's feature schema to pick the right
|
||||||
|
# raw keys and fold per-joint scalars into a single
|
||||||
|
# ``observation.state`` tensor. Then tensor-ise +
|
||||||
|
# device-place + add batch dim.
|
||||||
|
obs_tensors = build_inference_frame(
|
||||||
|
raw, torch_device, ds_features=ds_features,
|
||||||
|
task=task, robot_type=robot_type,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# No dataset features available — fall back to the
|
||||||
|
# generic numpy-only path; only works when the robot
|
||||||
|
# already returns dataset-shaped keys.
|
||||||
obs_tensors = prepare_observation_for_inference(
|
obs_tensors = prepare_observation_for_inference(
|
||||||
raw, torch_device, task=task, robot_type=robot_type
|
raw, torch_device, task=task, robot_type=robot_type,
|
||||||
)
|
)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("prepare_observation_for_inference failed: %s", exc)
|
logger.warning("observation prep failed: %s", exc)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if preprocessor is not None:
|
if preprocessor is not None:
|
||||||
@@ -869,6 +888,7 @@ def main(argv: list[str] | None = None) -> int:
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
device=str(getattr(policy.config, "device", "cpu")),
|
device=str(getattr(policy.config, "device", "cpu")),
|
||||||
task=args.task,
|
task=args.task,
|
||||||
|
ds_features=ds_meta.features if ds_meta is not None else None,
|
||||||
)
|
)
|
||||||
robot_executor = _build_robot_action_executor(
|
robot_executor = _build_robot_action_executor(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
|
|||||||
Reference in New Issue
Block a user