mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-28 06:59:44 +00:00
fix(smolvla2): pass flat batch dict to preprocessor (no manual wrap)
``PolicyProcessorPipeline.__call__`` already wraps its input via ``to_transition`` (defaulting to ``batch_to_transition``) before running the steps, and unwraps via ``to_output`` (defaulting to ``transition_to_batch``) afterwards. The input format is therefore a *flat batch dict* keyed by ``observation.*`` / ``action`` / etc., not an ``EnvTransition``. Previous attempt pre-wrapped the observation into a transition with ``TransitionKey.OBSERVATION`` as the key, then handed *that* to the pipeline — which fed it to ``batch_to_transition``, which looked for top-level ``observation.*`` entries, found none (they were nested inside the enum key), and produced an empty observation. Every step then bailed with ``ObservationProcessorStep requires an observation in the transition.`` Pass the flat dict from ``build_inference_frame`` straight to the preprocessor — it does the wrap/unwrap itself. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -588,7 +588,6 @@ def _build_robot_observation_provider(
|
|||||||
build_inference_frame,
|
build_inference_frame,
|
||||||
prepare_observation_for_inference,
|
prepare_observation_for_inference,
|
||||||
)
|
)
|
||||||
from lerobot.types import TransitionKey # noqa: PLC0415
|
|
||||||
|
|
||||||
torch_device = torch.device(device) if isinstance(device, str) else device
|
torch_device = torch.device(device) if isinstance(device, str) else device
|
||||||
robot_type = getattr(robot, "robot_type", None) or getattr(
|
robot_type = getattr(robot, "robot_type", None) or getattr(
|
||||||
@@ -629,28 +628,22 @@ def _build_robot_observation_provider(
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if preprocessor is not None:
|
if preprocessor is not None:
|
||||||
# ``EnvTransition``'s TypedDict is declared with
|
# ``PolicyProcessorPipeline`` defaults its ``to_transition``
|
||||||
# ``TransitionKey.OBSERVATION.value`` as keys, but every
|
# to ``batch_to_transition``, which expects a *flat batch
|
||||||
# ProcessorStep in the pipeline does
|
# dict* keyed by ``observation.*`` / ``action`` / etc., and
|
||||||
# ``transition.get(TransitionKey.OBSERVATION)`` / indexes
|
# wraps it into an ``EnvTransition`` itself. Pre-wrapping
|
||||||
# with the *enum member* — not the string ``.value``. Build
|
# here would just have ``batch_to_transition`` look for
|
||||||
# the dict with enum keys so the steps actually find the
|
# ``observation.*`` keys at top level, find none (they'd
|
||||||
# observation.
|
# be nested under ``TransitionKey.OBSERVATION``), and
|
||||||
transition: dict[Any, Any] = {
|
# produce an empty observation → ``ObservationProcessorStep``
|
||||||
TransitionKey.OBSERVATION: obs_tensors,
|
# bails. Pass the flat dict straight in; ``to_output``
|
||||||
TransitionKey.ACTION: None,
|
# gives us a flat dict back.
|
||||||
TransitionKey.REWARD: None,
|
|
||||||
TransitionKey.DONE: None,
|
|
||||||
TransitionKey.TRUNCATED: None,
|
|
||||||
TransitionKey.INFO: None,
|
|
||||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
|
||||||
}
|
|
||||||
try:
|
try:
|
||||||
transition = preprocessor(transition)
|
processed = preprocessor(obs_tensors)
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
logger.warning("preprocessor failed on robot observation: %s", exc)
|
logger.warning("preprocessor failed on robot observation: %s", exc)
|
||||||
return None
|
return None
|
||||||
obs_tensors = transition.get(TransitionKey.OBSERVATION) or {}
|
obs_tensors = processed if isinstance(processed, dict) else {}
|
||||||
|
|
||||||
observation = {
|
observation = {
|
||||||
k: v
|
k: v
|
||||||
|
|||||||
Reference in New Issue
Block a user