diff --git a/src/lerobot/policies/molmoact2/modeling_molmoact2.py b/src/lerobot/policies/molmoact2/modeling_molmoact2.py index 82ce7e02c..87ef54513 100644 --- a/src/lerobot/policies/molmoact2/modeling_molmoact2.py +++ b/src/lerobot/policies/molmoact2/modeling_molmoact2.py @@ -19,7 +19,7 @@ from __future__ import annotations import json import os import types -from collections import defaultdict, deque +from collections import deque from contextlib import nullcontext, suppress from pathlib import Path from typing import TYPE_CHECKING, Any @@ -652,7 +652,7 @@ class MolmoAct2Policy(PreTrainedPolicy): self._apply_norm_tag_metadata() self.config.validate_features() del inputs, kwargs, dataset_stats, dataset_meta - self._action_queues: dict[int, deque[Tensor]] = defaultdict(deque) + self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps) self._rollout_action_generator: torch.Generator | None = None self._rollout_task_key: tuple[Any, ...] | None = None self._rollout_index_for_task = -1 @@ -786,7 +786,7 @@ class MolmoAct2Policy(PreTrainedPolicy): self.train(self.training) def reset(self) -> None: - self._action_queues = defaultdict(deque) + self._action_queue = deque(maxlen=self.config.n_action_steps) self._rollout_action_generator = None def _set_inference_cuda_graph_enabled(self, enabled: bool) -> None: @@ -2048,18 +2048,11 @@ class MolmoAct2Policy(PreTrainedPolicy): def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor: if self._rtc_enabled(): raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk") - batch_size = int(next(iter(self._model_inputs(batch).values())).shape[0]) - actions: list[Tensor] = [] - for batch_idx in range(batch_size): - queue = self._action_queues[batch_idx] - if not queue: - chunk = self.predict_action_chunk(batch, **kwargs) - for step in torch.unbind(chunk[batch_idx], dim=0): - queue.append(step) - if not queue: - raise RuntimeError("MolmoAct2 produced an empty action chunk.") - actions.append(queue.popleft()) - return torch.stack(actions, dim=0) + self.eval() + if len(self._action_queue) == 0: + actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps] + self._action_queue.extend(actions.transpose(0, 1)) + return self._action_queue.popleft() def _get_default_peft_targets(self) -> dict[str, Any]: target_modules = self._lora_target_modules(prefix=r"model\.model") diff --git a/tests/policies/molmoact2/test_molmoact2.py b/tests/policies/molmoact2/test_molmoact2.py index 279471956..bb272a8a2 100644 --- a/tests/policies/molmoact2/test_molmoact2.py +++ b/tests/policies/molmoact2/test_molmoact2.py @@ -17,6 +17,7 @@ from __future__ import annotations import json +from collections import deque from types import SimpleNamespace import numpy as np @@ -621,6 +622,34 @@ def test_rtc_processor_initialization_and_select_action_guard(): policy.select_action({}) +def test_select_action_uses_single_full_batch_queue(): + policy = object.__new__(MolmoAct2Policy) + torch.nn.Module.__init__(policy) + policy.config = SimpleNamespace(rtc_config=None, n_action_steps=2) + policy._action_queue = deque(maxlen=2) + calls = 0 + + def predict_action_chunk(batch, **kwargs): + nonlocal calls + del batch, kwargs + calls += 1 + return torch.tensor( + [ + [[1.0], [2.0]], + [[3.0], [4.0]], + ] + ) + + policy.predict_action_chunk = predict_action_chunk + + first = policy.select_action({}) + second = policy.select_action({}) + + assert calls == 1 + assert torch.equal(first, torch.tensor([[1.0], [3.0]])) + assert torch.equal(second, torch.tensor([[2.0], [4.0]])) + + def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias(): policy = object.__new__(MolmoAct2Policy) torch.nn.Module.__init__(policy)