diff --git a/src/lerobot/scripts/lerobot_smolvla2_runtime.py b/src/lerobot/scripts/lerobot_smolvla2_runtime.py index c5d7dfb04..986b37194 100644 --- a/src/lerobot/scripts/lerobot_smolvla2_runtime.py +++ b/src/lerobot/scripts/lerobot_smolvla2_runtime.py @@ -588,7 +588,6 @@ def _build_robot_observation_provider( build_inference_frame, prepare_observation_for_inference, ) - from lerobot.types import TransitionKey # noqa: PLC0415 torch_device = torch.device(device) if isinstance(device, str) else device robot_type = getattr(robot, "robot_type", None) or getattr( @@ -629,28 +628,22 @@ def _build_robot_observation_provider( return None if preprocessor is not None: - # ``EnvTransition``'s TypedDict is declared with - # ``TransitionKey.OBSERVATION.value`` as keys, but every - # ProcessorStep in the pipeline does - # ``transition.get(TransitionKey.OBSERVATION)`` / indexes - # with the *enum member* — not the string ``.value``. Build - # the dict with enum keys so the steps actually find the - # observation. - transition: dict[Any, Any] = { - TransitionKey.OBSERVATION: obs_tensors, - TransitionKey.ACTION: None, - TransitionKey.REWARD: None, - TransitionKey.DONE: None, - TransitionKey.TRUNCATED: None, - TransitionKey.INFO: None, - TransitionKey.COMPLEMENTARY_DATA: {}, - } + # ``PolicyProcessorPipeline`` defaults its ``to_transition`` + # to ``batch_to_transition``, which expects a *flat batch + # dict* keyed by ``observation.*`` / ``action`` / etc., and + # wraps it into an ``EnvTransition`` itself. Pre-wrapping + # here would just have ``batch_to_transition`` look for + # ``observation.*`` keys at top level, find none (they'd + # be nested under ``TransitionKey.OBSERVATION``), and + # produce an empty observation → ``ObservationProcessorStep`` + # bails. Pass the flat dict straight in; ``to_output`` + # gives us a flat dict back. try: - transition = preprocessor(transition) + processed = preprocessor(obs_tensors) except Exception as exc: # noqa: BLE001 logger.warning("preprocessor failed on robot observation: %s", exc) return None - obs_tensors = transition.get(TransitionKey.OBSERVATION) or {} + obs_tensors = processed if isinstance(processed, dict) else {} observation = { k: v