mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
fix(smolvla2): pass flat batch dict to preprocessor (no manual wrap)
``PolicyProcessorPipeline.__call__`` already wraps its input via ``to_transition`` (defaulting to ``batch_to_transition``) before running the steps, and unwraps via ``to_output`` (defaulting to ``transition_to_batch``) afterwards. The input format is therefore a *flat batch dict* keyed by ``observation.*`` / ``action`` / etc., not an ``EnvTransition``. The previous attempt pre-wrapped the observation into a transition with ``TransitionKey.OBSERVATION`` as the key, then handed *that* to the pipeline. The pipeline fed it to ``batch_to_transition``, which looked for top-level ``observation.*`` entries, found none (they were nested inside the enum key), and produced an empty observation. Every step then bailed with ``ObservationProcessorStep requires an observation in the transition.`` Pass the flat dict from ``build_inference_frame`` straight to the preprocessor — it does the wrap/unwrap itself. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -588,7 +588,6 @@ def _build_robot_observation_provider(
|
||||
build_inference_frame,
|
||||
prepare_observation_for_inference,
|
||||
)
|
||||
from lerobot.types import TransitionKey # noqa: PLC0415
|
||||
|
||||
torch_device = torch.device(device) if isinstance(device, str) else device
|
||||
robot_type = getattr(robot, "robot_type", None) or getattr(
|
||||
@@ -629,28 +628,22 @@ def _build_robot_observation_provider(
|
||||
return None
|
||||
|
||||
if preprocessor is not None:
|
||||
# ``EnvTransition``'s TypedDict is declared with
|
||||
# ``TransitionKey.OBSERVATION.value`` as keys, but every
|
||||
# ProcessorStep in the pipeline does
|
||||
# ``transition.get(TransitionKey.OBSERVATION)`` / indexes
|
||||
# with the *enum member* — not the string ``.value``. Build
|
||||
# the dict with enum keys so the steps actually find the
|
||||
# observation.
|
||||
transition: dict[Any, Any] = {
|
||||
TransitionKey.OBSERVATION: obs_tensors,
|
||||
TransitionKey.ACTION: None,
|
||||
TransitionKey.REWARD: None,
|
||||
TransitionKey.DONE: None,
|
||||
TransitionKey.TRUNCATED: None,
|
||||
TransitionKey.INFO: None,
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
# ``PolicyProcessorPipeline`` defaults its ``to_transition``
|
||||
# to ``batch_to_transition``, which expects a *flat batch
|
||||
# dict* keyed by ``observation.*`` / ``action`` / etc., and
|
||||
# wraps it into an ``EnvTransition`` itself. Pre-wrapping
|
||||
# here would just have ``batch_to_transition`` look for
|
||||
# ``observation.*`` keys at top level, find none (they'd
|
||||
# be nested under ``TransitionKey.OBSERVATION``), and
|
||||
# produce an empty observation → ``ObservationProcessorStep``
|
||||
# bails. Pass the flat dict straight in; ``to_output``
|
||||
# gives us a flat dict back.
|
||||
try:
|
||||
transition = preprocessor(transition)
|
||||
processed = preprocessor(obs_tensors)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
logger.warning("preprocessor failed on robot observation: %s", exc)
|
||||
return None
|
||||
obs_tensors = transition.get(TransitionKey.OBSERVATION) or {}
|
||||
obs_tensors = processed if isinstance(processed, dict) else {}
|
||||
|
||||
observation = {
|
||||
k: v
|
||||
|
||||
Reference in New Issue
Block a user