mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
use a single molmoact2 action queue
This commit is contained in:
@@ -19,7 +19,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import types
|
||||
from collections import defaultdict, deque
|
||||
from collections import deque
|
||||
from contextlib import nullcontext, suppress
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
@@ -652,7 +652,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
self._apply_norm_tag_metadata()
|
||||
self.config.validate_features()
|
||||
del inputs, kwargs, dataset_stats, dataset_meta
|
||||
self._action_queues: dict[int, deque[Tensor]] = defaultdict(deque)
|
||||
self._action_queue: deque[Tensor] = deque(maxlen=self.config.n_action_steps)
|
||||
self._rollout_action_generator: torch.Generator | None = None
|
||||
self._rollout_task_key: tuple[Any, ...] | None = None
|
||||
self._rollout_index_for_task = -1
|
||||
@@ -786,7 +786,7 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
self.train(self.training)
|
||||
|
||||
def reset(self) -> None:
|
||||
self._action_queues = defaultdict(deque)
|
||||
self._action_queue = deque(maxlen=self.config.n_action_steps)
|
||||
self._rollout_action_generator = None
|
||||
|
||||
def _set_inference_cuda_graph_enabled(self, enabled: bool) -> None:
|
||||
@@ -2048,18 +2048,11 @@ class MolmoAct2Policy(PreTrainedPolicy):
|
||||
def select_action(self, batch: dict[str, Tensor], **kwargs) -> Tensor:
|
||||
if self._rtc_enabled():
|
||||
raise AssertionError("RTC is not supported for select_action, use it with predict_action_chunk")
|
||||
batch_size = int(next(iter(self._model_inputs(batch).values())).shape[0])
|
||||
actions: list[Tensor] = []
|
||||
for batch_idx in range(batch_size):
|
||||
queue = self._action_queues[batch_idx]
|
||||
if not queue:
|
||||
chunk = self.predict_action_chunk(batch, **kwargs)
|
||||
for step in torch.unbind(chunk[batch_idx], dim=0):
|
||||
queue.append(step)
|
||||
if not queue:
|
||||
raise RuntimeError("MolmoAct2 produced an empty action chunk.")
|
||||
actions.append(queue.popleft())
|
||||
return torch.stack(actions, dim=0)
|
||||
self.eval()
|
||||
if len(self._action_queue) == 0:
|
||||
actions = self.predict_action_chunk(batch, **kwargs)[:, : self.config.n_action_steps]
|
||||
self._action_queue.extend(actions.transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, Any]:
|
||||
target_modules = self._lora_target_modules(prefix=r"model\.model")
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections import deque
|
||||
from types import SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
@@ -621,6 +622,34 @@ def test_rtc_processor_initialization_and_select_action_guard():
|
||||
policy.select_action({})
|
||||
|
||||
|
||||
def test_select_action_uses_single_full_batch_queue():
|
||||
policy = object.__new__(MolmoAct2Policy)
|
||||
torch.nn.Module.__init__(policy)
|
||||
policy.config = SimpleNamespace(rtc_config=None, n_action_steps=2)
|
||||
policy._action_queue = deque(maxlen=2)
|
||||
calls = 0
|
||||
|
||||
def predict_action_chunk(batch, **kwargs):
|
||||
nonlocal calls
|
||||
del batch, kwargs
|
||||
calls += 1
|
||||
return torch.tensor(
|
||||
[
|
||||
[[1.0], [2.0]],
|
||||
[[3.0], [4.0]],
|
||||
]
|
||||
)
|
||||
|
||||
policy.predict_action_chunk = predict_action_chunk
|
||||
|
||||
first = policy.select_action({})
|
||||
second = policy.select_action({})
|
||||
|
||||
assert calls == 1
|
||||
assert torch.equal(first, torch.tensor([[1.0], [3.0]]))
|
||||
assert torch.equal(second, torch.tensor([[2.0], [4.0]]))
|
||||
|
||||
|
||||
def test_inference_action_mode_is_explicit_and_has_no_action_mode_alias():
|
||||
policy = object.__new__(MolmoAct2Policy)
|
||||
torch.nn.Module.__init__(policy)
|
||||
|
||||
Reference in New Issue
Block a user