use a single molmoact2 action queue

This commit is contained in:
hq-fang
2026-05-19 22:00:28 +00:00
parent 2a0495f8c3
commit 738ba9272f
2 changed files with 37 additions and 15 deletions
@@ -19,7 +19,7 @@ from __future__ import annotations
import json
import os
import types
from collections import defaultdict, deque
from collections import deque
from contextlib import nullcontext, suppress
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -652,7 +652,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
self._apply_norm_tag_metadata()
self.config.validate_features()
del inputs, kwargs, dataset_stats, dataset_meta
self._action_queues: dict[int, deque[Tensor]] = defaultdict(deque)
self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps)
self._rollout_action_generator: torch.Generator | None = None
self._rollout_task_key: tuple[Any, ...] | None = None
self._rollout_index_for_task = -1
@@ -786,7 +786,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
self.train(self.training)
def reset(self) -> None:
self._action_queues = defaultdict(deque)
self._action_queue = deque(maxlen=self.config.n_action_steps)
self._rollout_action_generator = None
def _set_inference_cuda_graph_enabled(self, enabled: bool) -> None:
@@ -2048,18 +2048,11 @@ class MolmoAct2Policy(PreTrainedPolicy):
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
if self._rtc_enabled():
raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk")
batch_size = int(next(iter(self._model_inputs(batch).values())).shape[0])
actions: list[Tensor] = []
for batch_idx in range(batch_size):
queue = self._action_queues[batch_idx]
if not queue:
chunk = self.predict_action_chunk(batch, **kwargs)
for step in torch.unbind(chunk[batch_idx], dim=0):
queue.append(step)
if not queue:
raise RuntimeError("MolmoAct2 produced an empty action chunk.")
actions.append(queue.popleft())
return torch.stack(actions, dim=0)
self.eval()
if len(self._action_queue) == 0:
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
self._action_queue.extend(actions.transpose(0, 1))
return self._action_queue.popleft()
def _get_default_peft_targets(self) -> dict[str, Any]:
target_modules = self._lora_target_modules(prefix=r"model\.model")
@@ -17,6 +17,7 @@
from __future__ import annotations
import json
from collections import deque
from types import SimpleNamespace
import numpy as np
@@ -621,6 +622,34 @@ def test_rtc_processor_initialization_and_select_action_guard():
policy.select_action({})
def test_select_action_uses_single_full_batch_queue():
policy = object.__new__(MolmoAct2Policy)
torch.nn.Module.__init__(policy)
policy.config = SimpleNamespace(rtc_config=None, n_action_steps=2)
policy._action_queue = deque(maxlen=2)
calls = 0
def predict_action_chunk(batch, **kwargs):
nonlocal calls
del batch, kwargs
calls += 1
return torch.tensor(
[
[[1.0], [2.0]],
[[3.0], [4.0]],
]
)
policy.predict_action_chunk = predict_action_chunk
first = policy.select_action({})
second = policy.select_action({})
assert calls == 1
assert torch.equal(first, torch.tensor([[1.0], [3.0]]))
assert torch.equal(second, torch.tensor([[2.0], [4.0]]))
def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias():
policy = object.__new__(MolmoAct2Policy)
torch.nn.Module.__init__(policy)