diff --git a/src/lerobot/policies/smolvla2/inference/runtime.py b/src/lerobot/policies/smolvla2/inference/runtime.py index afc7cca18..6605f72cb 100644 --- a/src/lerobot/policies/smolvla2/inference/runtime.py +++ b/src/lerobot/policies/smolvla2/inference/runtime.py @@ -33,8 +33,9 @@ from .steps import ( HighLevelSubtaskFwd, InferenceStep, LowLevelForward, + MemoryUpdateFwd, ) -from .triggers import HzTrigger, TickClock +from .triggers import EventTrigger, HzTrigger, TickClock logger = logging.getLogger(__name__) @@ -67,29 +68,40 @@ class SmolVLA2Runtime: _stop: bool = field(default=False, init=False) def __post_init__(self) -> None: - # Subtask + VQA configuration (current scope — plan and memory - # are not trained yet). Pipeline: + # Subtask + memory + VQA configuration. Pipeline: # # HighLevelSubtaskFwd → generate the next subtask via the LM # head at ~``high_level_hz``; writes - # ``current_subtask`` - # AskVQAFwd → answer camera-grounded stdin questions + # ``current_subtask`` and emits + # ``subtask_change`` on a transition. + # MemoryUpdateFwd → on ``subtask_change``, refresh + # ``current_memory`` from the + # ``memory_update`` head. + # AskVQAFwd → answer camera-grounded stdin questions. # LowLevelForward → action chunk conditioned on the - # generated ``current_subtask`` - # DispatchAction → drain the chunk to the robot - # DispatchToolCalls → fire any pending tool calls + # generated ``current_subtask``. + # DispatchAction → drain the chunk to the robot. + # DispatchToolCalls → fire any pending tool calls. # - # Order matters: ``HighLevelSubtaskFwd`` and ``LowLevelForward`` - # are both gated on "action queue empty", so the subtask must - # refresh *before* the chunk that consumes it. ``MemoryUpdateFwd`` - # / ``UserInterjectionFwd`` are still importable from - # ``inference.steps`` — re-add once plan / memory are in scope. + # Order matters: ``HighLevelSubtaskFwd`` must run before + # ``MemoryUpdateFwd`` so the event is visible the same tick, and + # both must run before ``LowLevelForward`` (which is gated on + # "action queue empty") so the chunk consumes the freshest + # subtask. ``UserInterjectionFwd`` is still importable but + # disabled until plan generation is wired in. self.pipeline = [ HighLevelSubtaskFwd( trigger=HzTrigger(self.high_level_hz), policy=self.policy, observation_provider=self.observation_provider, ), + # Listens for the ``subtask_change`` event raised by + # ``HighLevelSubtaskFwd`` and refreshes ``current_memory``. + MemoryUpdateFwd( + trigger=EventTrigger("subtask_change"), + policy=self.policy, + observation_provider=self.observation_provider, + ), AskVQAFwd( policy=self.policy, observation_provider=self.observation_provider,